In [3]:
!pip install jaxlib
!pip install jax
!pip install numpy
!pip install autograd



In [4]:
import jax.numpy as jnp
import jax.nn as jnn
import jax.lax as lax
from jax import grad, jit, vmap
from jax import random

In [5]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [6]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

159 ms ± 12.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

162 ms ± 2.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

166 ms ± 8.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

4.3 ms ± 171 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

807 µs ± 22.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [11]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [12]:
# Here are all the unique characters in the file
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("".join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [13]:
# create a mapping from characters to indices
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda x: ''.join([itos[i] for i in x])

In [14]:
print(encode('hello'))
print(decode(encode('hello')))

[46, 43, 50, 50, 53]
hello


In [15]:
data = jnp.array(encode(text), dtype=jnp.int32)
print(data.shape, data.dtype)
print(data[:10])

(1115394,) int32
[18 47 56 57 58  1 15 47 58 47]


In [16]:
# split the data in 90% training and 10% validation
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

In [17]:
block_size = 8
train_data[:block_size+1]

Array([18, 47, 56, 57, 58,  1, 15, 47, 58], dtype=int32)

In [18]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
  context = x[:t+1]
  target = y[t]
  print(f"When the input is {context} the target is {target}")

When the input is [18] the target is 47
When the input is [18 47] the target is 56
When the input is [18 47 56] the target is 57
When the input is [18 47 56 57] the target is 58
When the input is [18 47 56 57 58] the target is 1
When the input is [18 47 56 57 58  1] the target is 15
When the input is [18 47 56 57 58  1 15] the target is 47
When the input is [18 47 56 57 58  1 15 47] the target is 58


In [19]:
batch_size = 4 # how amny indepentend sequences will we process in parallel?
block_size = 8 # what is the maximum context length for prediction?

def get_batch(split):
  # generate a small batch of data of inputs x and targets y
  data = train_data if split == 'train' else val_data
  ix = random.randint(key, (batch_size,), 0, len(data) - block_size)
  x = jnp.stack([data[i:i+block_size] for i in ix])
  y = jnp.stack([data[i+1:i+block_size+1] for i in ix])
  return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('-----')

for b in range(batch_size): # batch dimension
  for t in range(block_size): # time dimension
    context = xb[b, :t+1]
    target = yb[b, t]
    print(f"When the input is {context} the target is {target}")

inputs:
(4, 8)
[[ 0 32 46 39 58  1 57 46]
 [57  6  1 40 63  1 63 53]
 [ 1 58 43 50 50  1 63 53]
 [ 0 37 53 59 56  1 51 53]]
targets:
(4, 8)
[[32 46 39 58  1 57 46 43]
 [ 6  1 40 63  1 63 53 59]
 [58 43 50 50  1 63 53 59]
 [37 53 59 56  1 51 53 58]]
-----
When the input is [0] the target is 32
When the input is [ 0 32] the target is 46
When the input is [ 0 32 46] the target is 39
When the input is [ 0 32 46 39] the target is 58
When the input is [ 0 32 46 39 58] the target is 1
When the input is [ 0 32 46 39 58  1] the target is 57
When the input is [ 0 32 46 39 58  1 57] the target is 46
When the input is [ 0 32 46 39 58  1 57 46] the target is 43
When the input is [57] the target is 6
When the input is [57  6] the target is 1
When the input is [57  6  1] the target is 40
When the input is [57  6  1 40] the target is 63
When the input is [57  6  1 40 63] the target is 1
When the input is [57  6  1 40 63  1] the target is 63
When the input is [57  6  1 40 63  1 63] the target is 53
Whe

In [20]:
# Create the embedding table
embedding_table = np.random.normal(size=(vocab_size, vocab_size)).astype(np.float32)

In [21]:
class BigramLanguageModel():
    def __init__(self):
        pass
  
    def __call__(self, idx, embedding_table, target=None):
    
        logits = embedding_table[idx]

        if target is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.reshape(-1, C)
            target = target.reshape(-1)
            loss = jnp.mean(jnp.take(logits, target, axis=0))

        return logits, loss
  
    def generate(self, idx, max_new_tokens, embedding_table):
        for _ in range(max_new_tokens):
            logits, _ = self(idx, embedding_table)
            logits = logits[:,-1,:]
            probs = jnn.softmax(logits, axis=1)
            print(f"The dimensions of the probs matrix: {probs.shape}")
            idx_next = random.categorical(key, probs, (1,)).item()
            # idx_next = jnp.reshape(idx_next, (1,1))
            print(idx_next.shape)
            print(idx.shape)
            idx = jnp.concatenate((idx, idx_next), axis=1)
        return idx


In [26]:
m = BigramLanguageModel()
logits, loss = m(xb, embedding_table, yb)
print(logits.shape)
print(loss)

(32, 65)
nan


In [25]:
idx = jnp.zeros((1,1), dtype=jnp.int32)
print(decode(m.generate(idx, 100, embedding_table).flatten()))

The dimensions of the probs matrix: (1, 65)


TypeError: '>=' not supported between instances of 'tuple' and 'int'

In [None]:
import jax.lax as lax
a = jnp.array([1,2,3])
b = jnp.array([4])

c = jnp.concatenate([a, b])