In [1]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [2]:
chars = sorted(list(set(text)))
vocab_size = len(chars)

In [3]:
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string


In [4]:
import jax
import jax.numpy as jnp
import optax
data = jnp.array(encode(text), dtype=jnp.int32)
print(data.shape, data.dtype)
print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this

(1115394,) int32
[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43
  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43
 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49
  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10
  0 37 53 59  1 39 56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39
 58 46 43 56  1 58 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47
 57 46 12  0  0 13 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53
 50 60 43 42  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47
 56 57 58  6  1 63 53 59  1 49 52 53 61  1 15 39 47 59 57  1 25 39 56 41
 47 59 57  1 47 57  1 41 46 47 43 44  1 43 52 43 51 63  1 58 53  1 58 46
 43  1 54 43 53 54 50 43  8  0  0 13 50 50 10  0 35 43  1 49 52 53 61  5
 58  6  1 61 43  1 49 52 53 61  5 58  8  0  0 18 47 56 57 58  1 15 47 58
 47 64 43 52 10  0 24 43 58  1 59 57  1 49 47 50 50  1 46 47 51  6  1 39
 52 42  1 61 43  5 50 50  1 46 39 

In [5]:
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [6]:
block_size = 8
train_data[:block_size+1]

Array([18, 47, 56, 57, 58,  1, 15, 47, 58], dtype=int32)

In [7]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is [18] the target: 47
when input is [18 47] the target: 56
when input is [18 47 56] the target: 57
when input is [18 47 56 57] the target: 58
when input is [18 47 56 57 58] the target: 1
when input is [18 47 56 57 58  1] the target: 15
when input is [18 47 56 57 58  1 15] the target: 47
when input is [18 47 56 57 58  1 15 47] the target: 58


In [8]:
import jax

rng = jax.random.PRNGKey(0)
batch_size = 4
block_size = 8

def get_batch(split, key):
    data_arr = train_data if split == 'train' else val_data
    
    # Generate random starting indices
    ix = jax.random.randint(key, (batch_size,), 0, len(data_arr) - block_size - 1)
    
    # Function to grab one slice, given one start index
    def get_slice(start_i):
        x = jax.lax.dynamic_slice(data_arr, (start_i,), (block_size,))
        y = jax.lax.dynamic_slice(data_arr, (start_i + 1,), (block_size,))
        return x, y
        
    # Vectorize this over the batch of indices
    x, y = jax.vmap(get_slice)(ix)
    return x, y


xb, yb = get_batch('train', rng)
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')


inputs:
(4, 8)
[[13 30 16  1 21 34 10  0]
 [41 53 60 43 56  1 58 53]
 [46 47 57  1 39 54 54 56]
 [ 1 52 39 63  6  1 57 53]]
targets:
(4, 8)
[[30 16  1 21 34 10  0 32]
 [53 60 43 56  1 58 53  1]
 [47 57  1 39 54 54 56 43]
 [52 39 63  6  1 57 53 51]]
----


In [9]:

class BigramLanguageModel:
    def __init__(self, vocab_size, key):
        self.vocab_size = vocab_size
        self.key = key
        self.token_embedding_table = jax.random.normal(key, (vocab_size, vocab_size)) * 0.01

    def __call__(self, idx,targets=None,params=None):
        tabel = params if params is not None else self.token_embedding_table
        logits = jnp.take(tabel, idx, axis=0)
        if targets is None:
            return logits, None
        B, T, C = logits.shape
        logits = logits.reshape(B*T, C)
        targets = targets.reshape(B*T)

        
        loss = optax.softmax_cross_entropy_with_integer_labels(logits,targets).mean()


        return logits,loss
    def generate(self, idx, max_new_tokens,key):
        for _ in range(max_new_tokens):
            logits,loss = self(idx, targets=None)
            logits = logits[:, -1, :] # focus only on the last time step
            key, subkey = jax.random.split(key)
            idx_next = jax.random.categorical(subkey, logits, axis=-1) # sample from the distribution
            idx_next = idx_next.reshape(-1, 1)
            idx = jnp.concatenate((idx, idx_next), axis=1) # append sampled index to the running sequence

        return idx
model = BigramLanguageModel(vocab_size, rng)
logits,loss = model(xb, yb)
print("logits shape:", logits.shape)  # (batch_size, block_size, vocab
print("loss :", loss)    # (batch_size, block_size)
max_new_tokens = 100
idx = jnp.zeros((1, 1), dtype=jnp.int32)
print(decode(model.generate(idx, max_new_tokens,rng)[0].tolist()))

logits shape: (32, 65)
loss : 4.1742873

3wuUShFFwYwel,e BTjJkFqurSsCQTsHKiDtNzNnHq&rJZzQLq.3EXbWOMVE$rvWJBU3dDLp;$Q;CIKkbtEJdB.JVwYYBwsWsCm,


In [10]:
optimizer = optax.adamw(learning_rate=1e-3)
batch_size = 32
params = model.token_embedding_table
opt_state = optimizer.init(params)

In [11]:

optimizer_state = optimizer.init(model.token_embedding_table)
for i in range(100):
    xb, yb = get_batch('train', rng)

    def loss_fn(params, xb, yb):
        logits, loss = model(xb, yb,params)
        return loss
    loss, grads = jax.value_and_grad(loss_fn)(model.token_embedding_table, xb, yb)
    updates, optimizer_state = optimizer.update(grads, optimizer_state,model.token_embedding_table)
    model.token_embedding_table = optax.apply_updates(model.token_embedding_table, updates)
    if i % 10 == 0:
        print(f"step {i}, loss {loss}")

step 0, loss 4.1749796867370605
step 10, loss 4.157289505004883
step 20, loss 4.139647006988525
step 30, loss 4.122057914733887
step 40, loss 4.104523181915283
step 50, loss 4.087052345275879
step 60, loss 4.069634437561035
step 70, loss 4.052274227142334
step 80, loss 4.0349860191345215
step 90, loss 4.017772197723389


In [12]:
print(decode(model.generate(idx, max_new_tokens,rng)[0].tolist()))


3wuUShFFwYwel,e BTjJkFqurSsCQTsHKiDtNzNnHq&rJZzQLq.3EXbWOMVE$rvWJBU3dDLp;$Q;CIKkbtEJdB.JVwYYBwsWsCm,


In [13]:
B, T, C = 4, 8, 2
x = jax.random.uniform(jax.random.PRNGKey(0), (B, T, C), minval=0, maxval=1)
x.shape

(4, 8, 2)

In [14]:
xbow = jnp.zeros((B, T, C), dtype=jnp.float32)
xprev = xbow[0,:1]
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1]
        xbow = xbow.at[b, t].set(jnp.mean(xprev, axis=0))

In [15]:
wei = jnp.tril(jnp.ones((T, T)))
wei = wei / wei.sum(axis=1, keepdims=True)
xbow2 = wei @ x


In [73]:
B, T, C = 4, 8, 32
x = jax.random.normal(jax.random.PRNGKey(0), (B, T, C))
head_size = 16
W_k = jax.random.normal(jax.random.PRNGKey(0), (C, head_size)) * (C ** -0.5)
W_q = jax.random.normal(jax.random.PRNGKey(0), (C, head_size)) * (C ** -0.5)
W_v = jax.random.normal(jax.random.PRNGKey(0), (C, head_size)) * (C ** -0.5)

k = x @ W_k # (B, T, head_size)
q = x @ W_q # (B, T, head_size)
wei = q @ k.transpose(0,2,1) * (head_size ** -0.5) # (B, T, 16) @ (B, head_size, T) -> (B, T, T) 
wei = jnp.where(jnp.tril(jnp.ones((T, T), dtype=bool)), wei, -jnp.inf)
wei = jax.nn.softmax(wei, axis=-1)
v = x @ W_v
out = wei @ v

print(out.shape)

(4, 8, 16)


TypeError: 'method' object is not iterable