In [27]:
import jax
import jax.numpy as jnp
from jax.nn.initializers import normal as normal_init
from flax.training.train_state import TrainState
from flax import linen as nn
import optax
from transformers import BertTokenizer
from jax import random

Hyperparameters

In [2]:
enc_layers = 6
head_count = 12
emb_size = 384 # from 12
seq_len = 36
drop_rate = 0.1

Get the data

In [3]:
!wget -O input.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-03-03 20:09:21--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-03-03 20:09:39 (61.8 KB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:

with open('input.txt', 'r') as f:
    data = f.read()

print(len(data))

1115394


In [5]:
print(data[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [6]:
# Split the data into lines
lines = data.splitlines()
lines[:10]

['First Citizen:',
 'Before we proceed any further, hear me speak.',
 '',
 'All:',
 'Speak, speak.',
 '',
 'First Citizen:',
 'You are all resolved rather to die than to famish?',
 '',
 'All:']

In [7]:
# Remove empty lines
no_sp_lines = []
for elem in lines:
    if elem != '':
        no_sp_lines.append(elem)
    else:
        continue

In [8]:
no_sp_lines[:5]


['First Citizen:',
 'Before we proceed any further, hear me speak.',
 'All:',
 'Speak, speak.',
 'First Citizen:']

In [9]:
# Concatenate lines that end with ':'
cat_lines = []
i = 0

while i < len(no_sp_lines):
    if no_sp_lines[i].endswith(':'):
        x = no_sp_lines[i] + ' ' + no_sp_lines[i + 1]
        cat_lines.append(x)
        i += 2
    else:
        cat_lines.append(no_sp_lines[i])
        i += 1

In [10]:
cat_lines[:30]

['First Citizen: Before we proceed any further, hear me speak.',
 'All: Speak, speak.',
 'First Citizen: You are all resolved rather to die than to famish?',
 'All: Resolved. resolved.',
 'First Citizen: First, you know Caius Marcius is chief enemy to the people.',
 "All: We know't, we know't.",
 "First Citizen: Let us kill him, and we'll have corn at our own price.",
 "Is't a verdict?",
 "All: No more talking on't; let it be done: away, away!",
 'Second Citizen: One word, good citizens.',
 'First Citizen: We are accounted poor citizens, the patricians good.',
 'What authority surfeits on would relieve us: if they',
 'would yield us but the superfluity, while it were',
 'wholesome, we might guess they relieved us humanely;',
 'but they think we are too dear: the leanness that',
 'afflicts us, the object of our misery, is as an',
 'inventory to particularise their abundance; our',
 'sufferance is a gain to them Let us revenge this with',
 'our pikes, ere we become rakes: for the gods kn

Tokenizers

In [11]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [12]:
# Special tokens
print("Special Tokens:")
print("CLS:", tokenizer.cls_token)
print("SEP:", tokenizer.sep_token)
print("PAD:", tokenizer.pad_token)
print("MASK:", tokenizer.mask_token)
print("UNK:", tokenizer.unk_token)

Special Tokens:
CLS: [CLS]
SEP: [SEP]
PAD: [PAD]
MASK: [MASK]
UNK: [UNK]


In [13]:
sentence = "To be, or not to be."
encoded = tokenizer.encode(sentence, add_special_tokens=True)

print("\nEncoded with special tokens:", encoded)
print("Decoded back:", tokenizer.decode(encoded))


Encoded with special tokens: [101, 2000, 2022, 1010, 2030, 2025, 2000, 2022, 1012, 102]
Decoded back: [CLS] to be, or not to be. [SEP]


In [58]:
def masker(encoded_text, key):
    masking_prob = 0.15
    masked_input_ids = encoded_text.clone()  # Make sure we don't modify the original tensor
    labels = encoded_text.clone()  # Same for labels

    for i in range(len(encoded_text)):
        key, subkey1, subkey2 = random.split(key, 3)
        if jax.random.uniform(key, minval=0, maxval=1) < masking_prob and masked_input_ids[i] not in tokenizer.all_special_ids:
            # Replace with [MASK] 80% of the time
            if jax.random.uniform(subkey1, minval=0, maxval=1) < 0.8:
                masked_input_ids = masked_input_ids.at[i].set(tokenizer.mask_token_id)
            # Replace with a random token 10% of the time
            elif jax.random.uniform(subkey2, minval=0, maxval=1) < 0.9:
                masked_input_ids = masked_input_ids.at[i].set(random.randint(subkey2,shape=(1,),minval=0, maxval=tokenizer.vocab_size - 1))
            # Keep the original word 10% of the time
            else:
                continue

    # Set the labels to -100 for the positions that were masked
    #labels[masked_input_ids == tokenizer.mask_token_id] = -100
    return masked_input_ids, labels#

In [15]:
test_encoded = tokenizer(
        cat_lines,
        padding='max_length',
        #truncation=True,
        max_length=36,
        return_tensors="jax"
)


In [16]:
test_encoded['input_ids'][0]

Array([  101,  2034,  6926,  1024,  2077,  2057, 10838,  2151,  2582,
        1010,  2963,  2033,  3713,  1012,   102,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0],      dtype=int32)

In [17]:
class BidirectionalSelfAttention(nn.Module):

    @nn.compact
    def __call__(self, x):
        query = nn.Dense(emb_size, use_bias=False)(x)
        key = nn.Dense(emb_size, use_bias=False)(x)
        value = nn.Dense(emb_size, use_bias=False)(x)

        attention_scores = jnp.matmul(query, key.T)
        attention_weights = jax.nn.softmax(attention_scores / key.shape[-1]**0.5, axis=-1)
        context_vector = jnp.matmul(attention_weights, value)
        return context_vector


In [18]:
class MultiHeadAttention(nn.Module):
    head_dim: int = emb_size // head_count

    @nn.compact
    def __call__(self, x, train=True):
        b, num_tokens, d_in = x.shape

        keys = nn.Dense(emb_size, use_bias=False)(x)
        queries = nn.Dense(emb_size, use_bias=False)(x)
        values = nn.Dense(emb_size, use_bias=False)(x)

        # Reshape to separate the head dimension
        keys = keys.reshape(b, num_tokens, head_count, self.head_dim)
        values = values.reshape(b, num_tokens, head_count, self.head_dim)
        queries = queries.reshape(b, num_tokens, head_count, self.head_dim)

        # Transpose for attention computation
        keys = jnp.transpose(keys, (0, 2, 1, 3))      # [b, h, n, d]
        values = jnp.transpose(values, (0, 2, 1, 3))  # [b, h, n, d]
        queries = jnp.transpose(queries, (0, 2, 1, 3)) # [b, h, n, d]

        # Compute attention scores
        attn_scores = jnp.matmul(queries, jnp.transpose(keys, (0, 1, 3, 2)))
        attn_weights = jax.nn.softmax(attn_scores / jnp.sqrt(self.head_dim), axis=-1)

        # Apply dropout during training
        if train:
            attn_weights = nn.Dropout(rate=drop_rate, deterministic=not train)(attn_weights)

        # Compute weighted sum
        context_vec = jnp.matmul(attn_weights, values)  # [b, h, n, d]
        context_vec = jnp.transpose(context_vec, (0, 2, 1, 3))  # [b, n, h, d]
        context_vec = context_vec.reshape(b, num_tokens, emb_size)

        # Final projection
        context_vec = nn.Dense(emb_size, use_bias=False)(context_vec)
        return context_vec

In [19]:
class FeedForward(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(emb_size * 4)(x)
        x = nn.gelu(x)
        x = nn.Dense(emb_size)(x)
        return x

In [20]:
class TransformerBlock(nn.Module):

    @nn.compact
    def __call__(self, x, train=True):
        shortcut = x
        x = nn.LayerNorm()(x)
        x = MultiHeadAttention()(x, train=train)
        x = nn.Dropout(rate=drop_rate, deterministic=not train)(x)
        x = x + shortcut
        shortcut = x
        x = nn.LayerNorm()(x)
        x = FeedForward()(x)
        x = nn.Dropout(rate=drop_rate, deterministic=not train)(x)
        x = x + shortcut
        return x

In [21]:
class BERTModel(nn.Module):
    @nn.compact
    def __call__(self, in_idx, train=True):
        batch_size, seq_len = in_idx.shape
        tok_embeds = nn.Embed(tokenizer.vocab_size, emb_size)(in_idx)
        pos_embeds = nn.Embed(tokenizer.vocab_size, emb_size)(jnp.arange(seq_len))
        x = tok_embeds + pos_embeds
        x = nn.Dropout(rate=drop_rate, deterministic=not train)(x)
        for _ in range(enc_layers):
            x = TransformerBlock()(x, train=train)
        return x



In [90]:
def get_sample(key):
    key1, key2 = random.split(key)
    index1 = random.randint(key1,shape=(1,), minval=0, maxval=len(test_encoded['input_ids']) - 1)
    random_line1 = test_encoded['input_ids'][index1]
    if jax.random.uniform(key2, minval=0, maxval=1) < 0.5 and index1 < len(cat_lines)-1:
        index2 = index1 + 1
        random_line2 = test_encoded['input_ids'][index2]
    else:
        index2 = random.randint(key2,shape=(1,), minval=0, maxval=len(test_encoded['input_ids']) - 1)
        random_line2 = test_encoded['input_ids'][index2]
    return random_line1, random_line2

In [91]:
get_sample(random.PRNGKey(0))


(Array([[  101,  2005,  2029,  2115,  6225,  1998,  2115,  4752,  2003,
         19175,  1005,  1040,  1025,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0]],      dtype=int32),
 Array([[  101,  1996,  4656,  9527,  1997, 21136,  1998,  1996,  2693,
          3085,  2015,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0]],      dtype=int32))

In [100]:
def concatenate_and_adjust(tensor1, tensor2, target_size=36):
    # Step 1: Strip zeros from the first tensor
    tensor1_nonzero = tensor1[tensor1 != 0]

    # Step 2: Concatenate the two tensors (taking the second token onward from tensor2)
    if tensor2.ndim > 1:
        concatenated = jnp.concatenate((tensor1_nonzero, tensor2[0, 1:]))
    else:
        concatenated = jnp.concatenate((tensor1_nonzero, tensor2[1:]))

    # Step 3: Trim or pad with zeros to match the target size
    if len(concatenated) > target_size:
        concatenated = concatenated[:target_size]  # Trim to target size
    elif len(concatenated) < target_size:
        padding = jnp.zeros(target_size - len(concatenated), dtype=concatenated.dtype)
        concatenated = jnp.concatenate((concatenated, padding))  # Pad with zeros

    return concatenated

In [101]:
a, b = get_sample(random.PRNGKey(0))
c = concatenate_and_adjust(a, b)
print(c)
print(masker(c, random.PRNGKey(2)))

[  101  2005  2029  2115  6225  1998  2115  4752  2003 19175  1005  1040
  1025   102  1996  4656  9527  1997 21136  1998  1996  2693  3085  2015
   102     0     0     0     0     0     0     0     0     0     0     0]
(Array([  101,  2005,  2029,  2115,   103,  1998,  2115,  4752,  2003,
         103,  1005,  1040,  1025,   102,  1996,  4656,  9527,  1997,
       21136,  1998,  1996,  2693,  3085,   103,   102,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0],      dtype=int32), Array([  101,  2005,  2029,  2115,  6225,  1998,  2115,  4752,  2003,
       19175,  1005,  1040,  1025,   102,  1996,  4656,  9527,  1997,
       21136,  1998,  1996,  2693,  3085,  2015,   102,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0],      dtype=int32))


In [107]:
emb_0 = test_encoded['input_ids'][0]
emb_1 = test_encoded['input_ids'][1]

baetch = masker(c, random.PRNGKey(2))[0]

In [108]:
model = BERTModel()
params = model.init(random.PRNGKey(0), baetch)

ValueError: not enough values to unpack (expected 2, got 1)

In [104]:
# Create a PRNG key for dropout
dropout_key = random.PRNGKey(0)

# Apply the model with the dropout key
outputs = model.apply(params, baetch, rngs={'dropout': dropout_key})

ValueError: not enough values to unpack (expected 2, got 1)

In [105]:
outputs

Array([[[-0.8792379 ,  2.6149247 ,  2.4732218 , ..., -1.1531113 ,
          1.3427374 ,  3.8533928 ],
        [-1.8717792 ,  1.0434839 ,  4.0751777 , ..., -1.9849678 ,
         -0.02247733,  2.038329  ],
        [-2.6814907 ,  3.689282  ,  3.0881677 , ..., -3.0864406 ,
          1.3580472 ,  2.094426  ],
        ...,
        [-1.043855  ,  3.3947158 ,  1.1680014 , ..., -1.5841985 ,
         -0.06316358,  2.5258775 ],
        [-2.7131894 ,  5.170878  ,  0.6641097 , ..., -2.290063  ,
         -1.1589377 ,  2.0320249 ],
        [-1.9094371 ,  1.2198724 ,  2.6021    , ..., -1.3415163 ,
         -1.1853034 ,  1.2561945 ]],

       [[-3.1883032 ,  5.320793  ,  3.510557  , ..., -1.8480575 ,
          1.8133299 ,  2.7664046 ],
        [-1.6124862 ,  5.208761  ,  4.3846183 , ..., -1.4063144 ,
          1.6901441 ,  3.7757235 ],
        [-2.323461  ,  2.7385266 ,  3.5919883 , ..., -1.0413522 ,
          1.2736855 ,  2.2689352 ],
        ...,
        [-3.6155784 ,  3.7178936 ,  3.97956   , ..., -

In [106]:
def count_params(params):
    return sum(x.size for x in jax.tree_util.tree_leaves(params))

print(count_params(params)/1e6, 'M parameters')

34.078464 M parameters


Optimizer

In [81]:
optimizer = optax.adam(3e-4)
opt_state = optimizer.init(params)

In [82]:
model_state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer,)

Training

In [83]:
def mse_loss(predictions, targets):
    """
    Compute Mean Squared Error loss in JAX.

    Args:
        predictions: The model predictions
        targets: The ground truth targets

    Returns:
        The mean squared error loss
    """
    squared_error = jnp.square(predictions - targets)
    return jnp.mean(squared_error)

In [94]:
def calculate_loss(state, params, batch, dropout_key):
    data, labels = batch
    logits = state.apply_fn(params, data, rngs={'dropout': dropout_key})

    loss = mse_loss(logits, labels)
    return loss

In [95]:
grad_fn = jax.value_and_grad(calculate_loss, argnums=1)

In [96]:
@jax.jit
def train_step(state, batch, dropout_key):
    loss, grads = grad_fn(state, state.params, batch, dropout_key)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

In [97]:
# Create the embedding layer
token_embedding_layer = nn.Embed(tokenizer.vocab_size, emb_size)

# Initialize the layer with random parameters
embedding_params = token_embedding_layer.init(random.PRNGKey(0), test_encoded['input_ids'][0])

# Apply the layer with the initialized parameters
embedding_output = token_embedding_layer.apply(embedding_params, test_encoded['input_ids'][0])

# Now you can check the shape
print(embedding_output.shape)

(36, 384)


In [98]:
max_iters = 500
eval_interval = 10
lossi = []


In [99]:
for i in range(max_iters):
    dropout_key, subkey1, subkey2 = random.split(dropout_key, 3)
    random_line1, random_line2 = get_sample(subkey1)
    y = concatenate_and_adjust(random_line1, random_line2)

    x, mask = masker(y, subkey2)
    x_b = jnp.expand_dims(x, axis=0)

    labels = token_embedding_layer.apply(embedding_params, y) + token_embedding_layer.apply(embedding_params, jnp.arange(seq_len))
    emb_labels = jnp.expand_dims(labels, axis=0)

    state, loss = train_step(model_state, (x_b, emb_labels), dropout_key)
    lossi.append(loss)
    if i % eval_interval == 0:
        print(f'Step {i}: Loss {loss}')
        eval_loss = calculate_loss(state, state.params, (x_b, emb_labels), dropout_key)
        print(f'Eval loss: {eval_loss}')

Step 0: Loss 7.392397403717041
Eval loss: 8.958538055419922


ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(1,) shape=()