<a href="https://colab.research.google.com/github/AbrahamKong/CMPE297-Transformers_and_Finetuning_with_LLMs/blob/main/nanogpt_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/AbrahamKong/CMPE297-Transformers_and_Finetuning_with_LLMs/main/data/input.txt

--2023-10-24 07:07:05--  https://raw.githubusercontent.com/AbrahamKong/CMPE297-Transformers_and_Finetuning_with_LLMs/main/data/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 212532 (208K) [text/plain]
Saving to: ‘input.txt’


2023-10-24 07:07:05 (5.16 MB/s) - ‘input.txt’ saved [212532/212532]



In [2]:
# read it in to inspect it
with open('input.txt', 'r', encoding="utf16") as f:
    text = f.read()

In [3]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  102434


In [4]:
# let's look at the first 1000 characters
print(text[:1000])

< Shakespeare -- THE COMEDY OF ERRORS >
< from Online Library of Liberty (http://oll.libertyfund.org) >
< Unicode .txt version by Mike Scott (http://www.lexically.net) >
< from "The Complete Works of William Shakespeare" >
< ed. with a glossary by W.J. Craig M.A. >
< (London: Oxford University Press, 1916) >
<STAGE DIR>
<Scene.—Ephesus.>
</STAGE DIR>


<ACT 1>


<SCENE 1>
<A Hall in the Duke's Palace.>
<STAGE DIR>
<Enter Duke, Ægeon, Gaoler, Officers, and other Attendants.>
</STAGE DIR>
<ÆGEON>	<1%>
	Proceed, Solinus, to procure my fall,
	And by the doom of death end woes and all.
</ÆGEON>

<DUKE>	<1%>
	Merchant of Syracusa, plead no more.
	I am not partial to infringe our laws:
	The enmity and discord which of late
	Sprung from the rancorous outrage of your duke
	To merchants, our well-dealing countrymen,
	Who, wanting guilders to redeem their lives,
	Have seal'd his rigorous statutes with their bloods,
	Excludes all pity from our threat'ning looks.
	For, since the mortal and intestin

In [5]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("".join(chars))
print(vocab_size)

	
 !"%'(),-./0123456789:;<>?ABCDEFGHIJKLMNOPQRSTUVWYZabcdefghijklmnopqrstuvwxyzÆæœ—
83


In [6]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: "".join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

[60, 61, 61, 2, 72, 60, 57, 70, 57]
hii there


In [7]:
# let's now encode the entire text dataset and store it into a torch.Tensor
import jax # we use Jax
import jax.numpy as jnp
data = jnp.array(encode(text))
print(data.shape, data.dtype)
print(data[:1000])



(102434,) int32
[25  2 46 60 53 63 57 71 68 57 53 70 57  2 10 10  2 47 35 32  2 30 42 40
 32 31 51  2 42 33  2 32 45 45 42 45 46  2 26  1 25  2 58 70 67 65  2 42
 66 64 61 66 57  2 39 61 54 70 53 70 77  2 67 58  2 39 61 54 57 70 72 77
  2  7 60 72 72 68 23 12 12 67 64 64 11 64 61 54 57 70 72 77 58 73 66 56
 11 67 70 59  8  2 26  1 25  2 48 66 61 55 67 56 57  2 11 72 76 72  2 74
 57 70 71 61 67 66  2 54 77  2 40 61 63 57  2 46 55 67 72 72  2  7 60 72
 72 68 23 12 12 75 75 75 11 64 57 76 61 55 53 64 64 77 11 66 57 72  8  2
 26  1 25  2 58 70 67 65  2  4 47 60 57  2 30 67 65 68 64 57 72 57  2 50
 67 70 63 71  2 67 58  2 50 61 64 64 61 53 65  2 46 60 53 63 57 71 68 57
 53 70 57  4  2 26  1 25  2 57 56 11  2 75 61 72 60  2 53  2 59 64 67 71
 71 53 70 77  2 54 77  2 50 11 37 11  2 30 70 53 61 59  2 40 11 28 11  2
 26  1 25  2  7 39 67 66 56 67 66 23  2 42 76 58 67 70 56  2 48 66 61 74
 57 70 71 61 72 77  2 43 70 57 71 71  9  2 14 22 14 19  8  2 26  1 25 46
 47 28 34 32  2 31 36 45 26  1 25 4

In [8]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [9]:
block_size = 8
train_data[:block_size+1]

Array([25,  2, 46, 60, 53, 63, 57, 71, 68], dtype=int32)

In [10]:

x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is [25] the target: 2
when input is [25  2] the target: 46
when input is [25  2 46] the target: 60
when input is [25  2 46 60] the target: 53
when input is [25  2 46 60 53] the target: 63
when input is [25  2 46 60 53 63] the target: 57
when input is [25  2 46 60 53 63 57] the target: 71
when input is [25  2 46 60 53 63 57 71] the target: 68


In [11]:
random_key = jax.random.PRNGKey(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

dynamic_slice_vmap = jax.vmap(jax.lax.dynamic_slice, in_axes=(None, 0, None))

@jax.jit
def get_batch(random_key, data):
    # generate a small batch of data of inputs x and targets y
    ix = jax.random.randint(random_key, shape=(batch_size, 1), minval=0, maxval=len(data)-block_size)
    x = dynamic_slice_vmap(data, ix, (block_size,))
    y = dynamic_slice_vmap(data, ix+1, (block_size,))
    return x, y

random_key, random_subkey = jax.random.split(random_key)
xb, yb = get_batch(random_subkey, train_data)
print("inputs:")
print(xb.shape)
print(xb)
print("targets:")
print(yb.shape)
print(yb)

print("----")

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")


inputs:
(4, 8)
[[35 11 26  1  1 25 31 45]
 [67 64 24  1  0 28 66 56]
 [58  2 65 61 66 57  9  1]
 [53 71 67 66 27  1  0 50]]
targets:
(4, 8)
[[11 26  1  1 25 31 45 42]
 [64 24  1  0 28 66 56  2]
 [ 2 65 61 66 57  9  1  0]
 [71 67 66 27  1  0 50 57]]
----
when input is [35] the target: 11
when input is [35, 11] the target: 26
when input is [35, 11, 26] the target: 1
when input is [35, 11, 26, 1] the target: 1
when input is [35, 11, 26, 1, 1] the target: 25
when input is [35, 11, 26, 1, 1, 25] the target: 31
when input is [35, 11, 26, 1, 1, 25, 31] the target: 45
when input is [35, 11, 26, 1, 1, 25, 31, 45] the target: 42
when input is [67] the target: 64
when input is [67, 64] the target: 24
when input is [67, 64, 24] the target: 1
when input is [67, 64, 24, 1] the target: 0
when input is [67, 64, 24, 1, 0] the target: 28
when input is [67, 64, 24, 1, 0, 28] the target: 66
when input is [67, 64, 24, 1, 0, 28, 66] the target: 56
when input is [67, 64, 24, 1, 0, 28, 66, 56] the target: 2
w

In [12]:
import jax
import flax.linen as nn
import optax


class BigramLanguageModel(nn.Module):
    @nn.compact
    def __call__(self, idx):
        return nn.Embed(vocab_size, vocab_size)(idx)

    def generate(self, random_key, params, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits = self.apply(params, idx[:, -1])
            # sample from the distribution
            random_key, random_subkey = jax.random.split(random_key)
            idx_next = jax.random.categorical(random_subkey, logits, axis=-1) # (B, 1)
            # append sampled index to the running sequence
            idx = jnp.concatenate((idx, idx_next.reshape(logits.shape[0], -1)), axis=1) # (B, T+1)
        return idx

m = BigramLanguageModel()
random_key, random_subkey = jax.random.split(random_key)
params = m.init(random_subkey, idx=xb)

logits = m.apply(params, xb)
labels = jax.nn.one_hot(yb, vocab_size)
print(logits.shape)
loss = jnp.mean(optax.softmax_cross_entropy(logits, labels))
print(loss)

random_key, random_subkey = jax.random.split(random_key)
print(decode(m.generate(random_subkey, params, idx=jnp.zeros((1, 1), dtype=jnp.int32), max_new_tokens=100)[0].tolist()))

(4, 8, 83)
4.439624
	vFl7
	-;"7P k!-p0qB:'yb8g—L-ZkZ—t:TF4Qzi;0oxda(uAYau%Uq—vpvxsoKbn;QRjB6.sAsey:t;æGLo1o
(ælyEMeccT%  


In [13]:
batch_size = 32
@jax.jit
def get_batch(random_key, data):
    # generate a small batch of data of inputs x and targets y
    ix = jax.random.randint(random_key, shape=(batch_size, 1), minval=0, maxval=len(data)-block_size)
    x = dynamic_slice_vmap(data, ix, (block_size,))
    y = dynamic_slice_vmap(data, ix+1, (block_size,))
    return x, y

@jax.jit
def cross_entropy_loss(params, xb, yb):
    logits = m.apply(params, xb)
    one_hot_encoded_labels = jax.nn.one_hot(yb, num_classes=vocab_size)
    return optax.softmax_cross_entropy(
        logits=logits, labels=one_hot_encoded_labels
    ).mean()

# create a PyTorch optimizer
optimizer = optax.adam(learning_rate=1e-3)
optimizer_state = optimizer.init(params)

In [14]:
for steps in range(10000): # increase number of steps for good results...
    # sample a batch of data
    random_key, random_subkey = jax.random.split(random_key)
    xb, yb = get_batch(random_subkey, train_data)

    # evaluate the loss
    loss, grad = jax.value_and_grad(cross_entropy_loss)(params, xb, yb)

    # update params
    update, optimizer_state = optimizer.update(
        grad, optimizer_state
    )
    params = optax.apply_updates(params, update)

print(loss.item())

1.8584221601486206


In [15]:
print(decode(m.generate(random_subkey, params, idx=jnp.zeros((1, 1), dtype=jnp.int32), max_new_tokens=500)[0].tolist()))

	MIANAnd cr wone SYROLUSYR.
	TADr bathorenI t, me y,
	<5%>
<8%>	<DR>
	<5%>
	Angrayoad ar d wot hes, EPHer p sor to, witintoute 'shofavind he erol6%>	Andy, cthagse s DR.
<</AGE Exthing, mathithe masw8IRO gnjBr; berigayou! t dve ilxvit menk

<DIPHOLUSyo tthipoto O patelu me er lowiella ht, fou, maw lou hea pphe fis wn y,
	CE h rrest horthis wherediounsindang mat d tirinch dinothids lf ch t sor.
	ANTIOMiay DRCEndshou Jmite us, t teis lis wan tat herite sontin m.
<ABu?
	<70%>

	<DIRO>
	GEPHOLUS lf f 


## The mathematical trick in self-attention

In [16]:
# toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
random_key, random_subkey = jax.random.split(random_key)
a = jnp.tril(jnp.ones((3, 3)))
a = a / jnp.sum(a, 1, keepdims=True)
b = jax.random.randint(random_subkey, (3, 2), 0, 10)
c = a @ b
print("a=")
print(a)
print("--")
print("b=")
print(b)
print("--")
print("c=")
print(c)

a=
[[1.         0.         0.        ]
 [0.5        0.5        0.        ]
 [0.33333334 0.33333334 0.33333334]]
--
b=
[[4 2]
 [9 9]
 [0 5]]
--
c=
[[4.        2.       ]
 [6.5       5.5      ]
 [4.3333335 5.3333335]]


In [17]:
# consider the following toy example:
B, T, C = 4, 8, 2 # batch, time, channels
random_key, random_subkey = jax.random.split(random_key)
x = jax.random.normal(random_subkey, (B, T, C))
x.shape

(4, 8, 2)

In [18]:

# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = jnp.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t,C)
        xbow = xbow.at[b, t].set(jnp.mean(xprev, 0))

In [19]:
# version 2: using matrix multiply for a weighted aggregation
wei = jnp.tril(jnp.ones((T, T)))
wei = wei / wei.sum(1, keepdims=True)
xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)
jnp.allclose(xbow, xbow2)

Array(True, dtype=bool)

In [20]:
tril = jnp.tril(jnp.ones((T, T)))
nn.softmax(jnp.where(tril == 0, -jnp.inf, 0.), axis=-1)

Array([[1.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.5       , 0.5       , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.33333334, 0.33333334, 0.33333334, 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.25      , 0.25      , 0.25      , 0.25      , 0.        ,
        0.        , 0.        , 0.        ],
       [0.2       , 0.2       , 0.2       , 0.2       , 0.2       ,
        0.        , 0.        , 0.        ],
       [0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,
        0.16666667, 0.        , 0.        ],
       [0.14285715, 0.14285715, 0.14285715, 0.14285715, 0.14285715,
        0.14285715, 0.14285715, 0.        ],
       [0.125     , 0.125     , 0.125     , 0.125     , 0.125     ,
        0.125     , 0.125     , 0.125     ]], dtype=float32)

In [21]:
# version 3: use Softmax
tril = jnp.tril(jnp.ones((T, T)))
wei = nn.softmax(jnp.where(tril == 0, -jnp.inf, 0.), axis=-1)
xbow3 = wei @ x
jnp.allclose(xbow, xbow3)

Array(True, dtype=bool)

In [22]:
# version 4: self-attention!
B, T, C = 4, 8, 32 # batch, time, channels
random_key, random_subkey = jax.random.split(random_key)
x = jax.random.normal(random_subkey, (B, T, C))

# let's see a single Head perform self-attention
head_size = 16
key = nn.Dense(head_size, use_bias=False)
query = nn.Dense(head_size, use_bias=False)
value = nn.Dense(head_size, use_bias=False)

# Key
random_key, random_subkey = jax.random.split(random_key)
params_key = key.init(random_subkey, x)
k = key.apply(params_key, x) # (B, T, 16)

# Query
random_key, random_subkey = jax.random.split(random_key)
params_query = query.init(random_subkey, x)
q = query.apply(params_query, x) # (B, T, 16)
wei =  q @ jnp.transpose(k, axes=(0, 2, 1)) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = jnp.tril(jnp.ones((T, T)))
wei = nn.softmax(jnp.where(tril == 0, -jnp.inf, wei), axis=-1)

# Value
random_key, random_subkey = jax.random.split(random_key)
params_value = value.init(random_subkey, x)
v = value.apply(params_value, x)
out = wei @ v

out.shape

(4, 8, 16)

In [23]:
wei[0]

Array([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [3.5935986e-01, 6.4064014e-01, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [8.5110150e-02, 4.4947963e-02, 8.6994183e-01, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [9.3279088e-01, 6.6124655e-02, 3.7592688e-06, 1.0807255e-03,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [1.2928512e-02, 9.0569287e-04, 3.0436949e-04, 1.8784885e-01,
        7.9801261e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [1.7108747e-05, 1.9536851e-10, 3.1041065e-05, 8.6323820e-02,
        9.1347349e-01, 1.5456324e-04, 0.0000000e+00, 0.0000000e+00],
       [3.5270365e-04, 1.7646629e-05, 1.3294174e-04, 7.8084189e-01,
        1.7192353e-01, 3.5608148e-06, 4.6727713e-02, 0.0000000e+00],
       [2.3761153e-02, 6.2033951e-02, 1.7

In [24]:
random_key, random_subkey = jax.random.split(random_key)
k = jax.random.normal(random_subkey, (B, T, head_size))
random_key, random_subkey = jax.random.split(random_key)
q = jax.random.normal(random_subkey, (B, T, head_size))
wei = q @ jnp.transpose(k, axes=(0, 2, 1)) * head_size**-0.5

In [25]:
k.var()

Array(1.0069917, dtype=float32)

In [26]:
q.var()

Array(1.018728, dtype=float32)

In [27]:
wei.var()

Array(0.8800269, dtype=float32)

In [28]:
nn.softmax(jnp.array([0.1, -0.2, 0.3, -0.2, 0.5]), axis=-1)

Array([0.1924978 , 0.14260589, 0.23511736, 0.14260589, 0.287173  ],      dtype=float32)

In [29]:
nn.softmax(jnp.array([0.1, -0.2, 0.3, -0.2, 0.5])*8, axis=-1) # gets too peaky, converges to one-hot

Array([0.03260834, 0.00295816, 0.16151018, 0.00295816, 0.79996514],      dtype=float32)

In [30]:
class LayerNorm(nn.Module):
    epsilon: float = 1e-6
    reduction_axes = -1

    @nn.compact
    def __call__(self, x):
        """Applies layer normalization on the input."""
        # compute statistics
        mean2 = jnp.mean(jax.lax.square(x), self.reduction_axes, keepdims=True)
        mean = jnp.mean(x, self.reduction_axes, keepdims=True)
        var = jnp.maximum(0., mean2 - jax.lax.square(mean))

        # compute normalized inputs
        x_norm = (x - mean) * jax.lax.rsqrt(var + self.epsilon)
        return x_norm * self.param("scale", nn.initializers.ones, x.shape[-1]) + self.param("bias", nn.initializers.zeros, x.shape[-1])

random_key, random_subkey = jax.random.split(random_key)
module = LayerNorm()
x = jax.random.normal(random_subkey, (32, 100))
params = module.init(random_subkey, x)
x = module.apply(params, x)
x.shape

(32, 100)

In [31]:
x[:,0].mean(), x[:,0].std() # mean,std of one feature across all batch inputs


(Array(-0.05891828, dtype=float32), Array(1.0908911, dtype=float32))

In [32]:
x[0,:].mean(), x[0,:].std() # mean,std of a single input from the batch, of its features


(Array(7.152557e-09, dtype=float32), Array(0.9999992, dtype=float32))

# Reference

[nanoGPT-JAX](https://github.com/maxencefaldor/nanoGPT-JAX)

[Let's build GPT: from scratch, in code, spelled out.](https://www.youtube.com/watch?v=kCc8FmEb1nY)

[nanoGPT GitHub repository](https://github.com/karpathy/nanoGPT)

[Attention Is All You Need paper](https://arxiv.org/abs/1706.03762)

[GPT-3 pape](https://arxiv.org/abs/2005.14165)