In [15]:
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.training import train_state
import optax
from collections import Counter


Load the Data, Tokenize, and Build the Vocabulary

In [1]:
# Shakespeare dataset
# Download Karpathy's tiny-shakespeare corpus (~1.1 MB of plain text) to ./input.txt.
# NOTE(review): this re-downloads on every run; consider guarding it with a file-exists check.
!wget -O input.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-03-06 09:35:49--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-03-06 09:35:50 (2.13 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [3]:
# Read the whole corpus as one string. The encoding is explicit so the result
# does not depend on the platform's default locale encoding.
with open('input.txt', 'r', encoding='utf-8') as f:
    data = f.read()

print(len(data))

1115394


In [5]:
# Split the raw text into lines and drop the blank ones.
lines = data.splitlines()
no_sp_lines = [line for line in lines if line != '']

# A line ending in ':' is usually a bare speaker tag ("First Citizen:"). Join it
# with the following line so each entry starts with its speaker.
# NOTE(review): any verse line that happens to end in ':' is joined as well
# (e.g. "... the rein: Fit thy consent ..." in the test split), so the heuristic is approximate.
cat_lines = []
i = 0
while i < len(no_sp_lines):
    current = no_sp_lines[i]
    has_next = i + 1 < len(no_sp_lines)
    if current.endswith(':') and has_next:
        cat_lines.append(current + ' ' + no_sp_lines[i + 1])
        i += 2
    else:
        # Regular line, or a trailing ':' line with nothing left to join: keep it as is
        # (the old code raised IndexError here).
        cat_lines.append(current)
        i += 1

In [6]:
# Peek at the first 30 merged entries to sanity-check the speaker-tag joining.
cat_lines[0:30]

['First Citizen: Before we proceed any further, hear me speak.',
 'All: Speak, speak.',
 'First Citizen: You are all resolved rather to die than to famish?',
 'All: Resolved. resolved.',
 'First Citizen: First, you know Caius Marcius is chief enemy to the people.',
 "All: We know't, we know't.",
 "First Citizen: Let us kill him, and we'll have corn at our own price.",
 "Is't a verdict?",
 "All: No more talking on't; let it be done: away, away!",
 'Second Citizen: One word, good citizens.',
 'First Citizen: We are accounted poor citizens, the patricians good.',
 'What authority surfeits on would relieve us: if they',
 'would yield us but the superfluity, while it were',
 'wholesome, we might guess they relieved us humanely;',
 'but they think we are too dear: the leanness that',
 'afflicts us, the object of our misery, is as an',
 'inventory to particularise their abundance; our',
 'sufferance is a gain to them Let us revenge this with',
 'our pikes, ere we become rakes: for the gods kn

In [7]:
# Number of entries left after dropping blank lines and merging speaker tags.
len(cat_lines)

24618

In [8]:
# Contiguous 80/20 split with no shuffling: the test set is the tail of the text.
split_at = int(0.8 * len(cat_lines))
train, test = cat_lines[:split_at], cat_lines[split_at:]

In [9]:
# First ten training entries.
train[0:10]

['First Citizen: Before we proceed any further, hear me speak.',
 'All: Speak, speak.',
 'First Citizen: You are all resolved rather to die than to famish?',
 'All: Resolved. resolved.',
 'First Citizen: First, you know Caius Marcius is chief enemy to the people.',
 "All: We know't, we know't.",
 "First Citizen: Let us kill him, and we'll have corn at our own price.",
 "Is't a verdict?",
 "All: No more talking on't; let it be done: away, away!",
 'Second Citizen: One word, good citizens.']

In [10]:
# First ten test entries (where the contiguous split begins).
test[0:10]

['What man thou art.',
 'ANGELO: Who will believe thee, Isabel?',
 "My unsoil'd name, the austereness of my life,",
 "My vouch against you, and my place i' the state,",
 'Will so your accusation overweigh,',
 'That you shall stifle in your own report',
 'And smell of calumny. I have begun,',
 'And now I give my sensual race the rein: Fit thy consent to my sharp appetite;',
 'Lay by all nicety and prolixious blushes,',
 'That banish what they sue for; redeem thy brother']

In [11]:
# Sizes of the two splits. Together they must cover cat_lines exactly, with no overlap.
assert len(train) + len(test) == len(cat_lines), "train/test split lost or duplicated lines"
len(train), len(test)

(19694, 4924)

In [12]:
def simple_tokenize(sentence):
    """Split ``sentence`` on whitespace after deleting the characters . , ? ! : ;

    Case and apostrophes are preserved, so "Speak"/"speak" and "know't" stay
    distinct tokens.

    Returns:
        list[str]: the tokens, or an empty list for blank input.
    """
    # One translate pass deletes every listed punctuation character.
    stripped = sentence.translate(str.maketrans("", "", ".,?!:;"))
    # split() with no argument already ignores leading and trailing whitespace.
    return stripped.split()

In [13]:
# Tokenize each split independently with the same rule-based tokenizer.
train_tokens = list(map(simple_tokenize, train))
test_tokens = list(map(simple_tokenize, test))

In [14]:
# Tokenized form of the first ten training entries.
train_tokens[0:10]

[['First',
  'Citizen',
  'Before',
  'we',
  'proceed',
  'any',
  'further',
  'hear',
  'me',
  'speak'],
 ['All', 'Speak', 'speak'],
 ['First',
  'Citizen',
  'You',
  'are',
  'all',
  'resolved',
  'rather',
  'to',
  'die',
  'than',
  'to',
  'famish'],
 ['All', 'Resolved', 'resolved'],
 ['First',
  'Citizen',
  'First',
  'you',
  'know',
  'Caius',
  'Marcius',
  'is',
  'chief',
  'enemy',
  'to',
  'the',
  'people'],
 ['All', 'We', "know't", 'we', "know't"],
 ['First',
  'Citizen',
  'Let',
  'us',
  'kill',
  'him',
  'and',
  "we'll",
  'have',
  'corn',
  'at',
  'our',
  'own',
  'price'],
 ["Is't", 'a', 'verdict'],
 ['All',
  'No',
  'more',
  'talking',
  "on't",
  'let',
  'it',
  'be',
  'done',
  'away',
  'away'],
 ['Second', 'Citizen', 'One', 'word', 'good', 'citizens']]

In [16]:
# Flatten every tokenized sentence (train first, then test) into one token stream.
# NOTE(review): test tokens are included. If this list feeds the vocabulary,
# test-only words get real ids instead of [UNK]; confirm this leakage is intended.
all_tokens = [token for sentence in (*train_tokens, *test_tokens) for token in sentence]

In [17]:
# Vocabulary: five special tokens first (ids 0-4), then every distinct token in
# sorted order. Frequencies are not used for filtering; only the keys matter.
counts = Counter(all_tokens)
special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
vocab = special_tokens + sorted(counts.keys())

word2idx = dict(zip(vocab, range(len(vocab))))
idx2word = {idx: tok for tok, idx in word2idx.items()}
vocab_size = len(vocab)
print(f"Vocab size: {vocab_size}")

Vocab size: 15539


In [19]:
# Encoding Sentences
def encode(tokens, word2idx, max_len=12):
    """Map a token list to exactly ``max_len`` ids: [CLS] tok... [SEP] [PAD]...

    Unknown tokens become [UNK]. Long inputs are truncated *before* the special
    tokens are added, so [SEP] is kept rather than cut off. Previously, any
    sentence longer than ``max_len - 2`` tokens lost its [SEP].

    Args:
        tokens: list of string tokens (not modified).
        word2idx: vocabulary mapping containing "[PAD]", "[UNK]", "[CLS]" and "[SEP]".
        max_len: total output length, including [CLS] and [SEP].

    Returns:
        list[int] of length ``max_len`` (for ``max_len >= 0``).
    """
    unk_id = word2idx["[UNK]"]
    # Reserve two slots for [CLS]/[SEP]; max(..., 0) keeps very small max_len values safe.
    body = tokens[:max(max_len - 2, 0)]
    token_ids = (
        [word2idx["[CLS]"]]
        + [word2idx.get(t, unk_id) for t in body]
        + [word2idx["[SEP]"]]
    )
    token_ids += [word2idx["[PAD]"]] * (max_len - len(token_ids))
    # The cut only has an effect when max_len < 2 (no room for both markers).
    return token_ids[:max_len]

In [20]:
# Encode every sentence to a fixed-length id list, then stack each split into a
# 2-D integer array of shape (num_sentences, max_len).
train_enc = jnp.array([encode(sent, word2idx) for sent in train_tokens])
test_enc = jnp.array([encode(sent, word2idx) for sent in test_tokens])
print(f"Encoded dataset shape: {train_enc.shape}")

Encoded dataset shape: (19694, 12)


Masking

In [82]:
def mask(batch, key, mask_prob=0.15):
    """Replace random non-special tokens with [MASK] (vectorized, jit-compatible).

    Each token outside [PAD]/[CLS]/[SEP]/[MASK] is independently selected with
    probability ``mask_prob``; [UNK] tokens can be selected. Selected tokens are
    always replaced by [MASK] (no 80/10/10 random/keep split).

    Args:
        batch: integer array of token ids (any shape).
        key: JAX PRNG key. It is split, and one half drives the selection.
        mask_prob: per-token selection probability.

    Returns:
        (input_ids, labels, model_key):
        - input_ids: the corrupted ids.
        - labels: the original id at selected positions and -100 elsewhere.
        - model_key: the other half of ``key``, for the caller to use (e.g. dropout).
    """
    masking_key, model_key = jax.random.split(key)

    # Token ids that must never be selected for masking.
    never_mask = jnp.array([word2idx[tok] for tok in ("[PAD]", "[CLS]", "[SEP]", "[MASK]")])
    maskable = ~jnp.isin(batch, never_mask)

    draws = jax.random.uniform(masking_key, shape=batch.shape)
    selected = maskable & (draws < mask_prob)

    input_ids = jnp.where(selected, word2idx["[MASK]"], batch)
    labels = jnp.where(selected, batch, -100)
    return input_ids, labels, model_key

Transformer Architecture

In [27]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Attributes:
        embed_dim: projection width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """
    embed_dim: int
    num_heads: int

    @nn.compact
    def __call__(self, x, mask=None):
        """Attend across the second axis of ``x``.

        Args:
            x: array of shape (B, S, E). E may differ from ``embed_dim`` because
                the inputs are projected by Dense layers first.
            mask: optional boolean array broadcastable to (B, num_heads, S, S).
                Key positions where it is False receive (numerically) zero
                attention weight. The default None lets every query attend to
                every position, including [PAD] positions; callers can pass e.g.
                ``(ids != pad_id)[:, None, None, :]`` to exclude padding.

        Returns:
            Array of shape (B, S, embed_dim).
        """
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
        head_dim = self.embed_dim // self.num_heads
        B, S, _ = x.shape

        def split_heads(t):
            # (B, S, embed_dim) -> (B, num_heads, S, head_dim)
            return t.reshape((B, S, self.num_heads, head_dim)).transpose((0, 2, 1, 3))

        # Keep the creation order (query, key, value, output). Flax auto-names
        # compact submodules by creation order, so this keeps parameter names stable.
        query = split_heads(nn.Dense(self.embed_dim)(x))
        key = split_heads(nn.Dense(self.embed_dim)(x))
        value = split_heads(nn.Dense(self.embed_dim)(x))

        scores = jnp.matmul(query, key.transpose((0, 1, 3, 2))) / jnp.sqrt(head_dim)
        if mask is not None:
            # Use the most negative finite value (not -inf) so fully-masked rows stay NaN-free.
            scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
        attn_weights = nn.softmax(scores, axis=-1)
        context = jnp.matmul(attn_weights, value)
        context = context.transpose((0, 2, 1, 3)).reshape((B, S, self.embed_dim))
        return nn.Dense(self.embed_dim)(context)

In [32]:
class TransformerEncoderBlock(nn.Module):
    """Post-LayerNorm Transformer encoder layer.

    Self-attention is followed by a ReLU feed-forward network. Each sub-layer's
    output goes through dropout, a residual add, and then LayerNorm.
    NOTE(review): attention is called without any padding mask, so [PAD]
    positions are attended to.

    Attributes:
        embed_dim: width of the residual stream.
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward sub-layer.
        dropout_rate: dropout applied to each sub-layer output before the residual add.
    """
    embed_dim: int
    num_heads: int
    ff_dim: int
    dropout_rate: float

    @nn.compact
    def __call__(self, x, deterministic=False):
        # Attention sub-layer.
        attended = MultiHeadSelfAttention(self.embed_dim, self.num_heads)(x)
        post_attn_norm = nn.LayerNorm()
        attn_dropout = nn.Dropout(rate=self.dropout_rate, deterministic=deterministic)
        hidden = post_attn_norm(x + attn_dropout(attended))

        # Feed-forward sub-layer: expand to ff_dim, ReLU, project back to embed_dim.
        expanded = nn.relu(nn.Dense(self.ff_dim)(hidden))
        projected = nn.Dense(self.embed_dim)(expanded)
        post_ff_norm = nn.LayerNorm()
        ff_dropout = nn.Dropout(rate=self.dropout_rate, deterministic=deterministic)
        return post_ff_norm(hidden + ff_dropout(projected))


In [43]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input.

    Even feature columns get sin(pos * w_k) and odd columns get cos(pos * w_k),
    with w_k = 10000^(-2k / embed_dim).

    Attributes:
        embed_dim: feature width of the input (odd widths are supported).
        max_len: number of positions precomputed; longer inputs are rejected.
    """
    embed_dim: int
    max_len: int = 1000

    def setup(self):
        # Build the (max_len, embed_dim) table once per module binding.
        position = jnp.arange(self.max_len)[:, None]
        div_term = jnp.exp(jnp.arange(0, self.embed_dim, 2) * -(jnp.log(10000.0) / self.embed_dim))
        angles = position * div_term  # (max_len, ceil(embed_dim / 2))

        pe = jnp.zeros((self.max_len, self.embed_dim))
        pe = pe.at[:, 0::2].set(jnp.sin(angles))
        # With an odd embed_dim there is one fewer odd column than even column;
        # slicing the angles avoids the shape mismatch that used to crash here.
        pe = pe.at[:, 1::2].set(jnp.cos(angles[:, : self.embed_dim // 2]))
        self.pe = pe

    def __call__(self, x):
        """Return ``x + pe[:seq_len]``. Expects x of shape (batch, seq_len, embed_dim)."""
        seq_len = x.shape[1]
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding.max_len={self.max_len}"
            )
        return x + self.pe[:seq_len, :]

In [47]:
class Bert(nn.Module):
    """Small BERT-style encoder with a per-position vocabulary head.

    The pipeline is: token ids -> embedding -> sinusoidal positions -> dropout
    -> ``num_layers`` encoder blocks -> Dense logits of size ``vocab_size``.
    NOTE(review): no attention mask is derived from [PAD] ids, and ``max_len``
    must be at least the input sequence length.

    Attributes:
        vocab_size: number of token ids (embedding rows and output logits).
        embed_dim: model width.
        max_len: positions covered by the positional encoding.
        num_heads: attention heads per block.
        ff_dim: feed-forward hidden width per block.
        num_layers: number of stacked encoder blocks.
        dropout_rate: dropout probability used throughout.
    """
    vocab_size: int
    embed_dim: int = 64
    max_len: int = 12
    num_heads: int = 2
    ff_dim: int = 128
    num_layers: int = 2
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, deterministic=False):
        hidden = nn.Embed(self.vocab_size, self.embed_dim)(x)
        hidden = PositionalEncoding(self.embed_dim, self.max_len)(hidden)
        hidden = nn.Dropout(rate=self.dropout_rate, deterministic=deterministic)(hidden)

        for _layer in range(self.num_layers):
            block = TransformerEncoderBlock(self.embed_dim, self.num_heads, self.ff_dim, self.dropout_rate)
            hidden = block(hidden, deterministic=deterministic)

        # Unnormalized logits over the vocabulary at every position.
        return nn.Dense(self.vocab_size)(hidden)

Train the Model

In [50]:
# Build the model and create its variables from a dummy int32 batch of shape (1, max_len).
# `params` holds the full variables object returned by `model.init`; it is passed
# unchanged to `model.apply` later.
model = Bert(vocab_size=vocab_size)
dummy_ids = jnp.ones((1, model.max_len), dtype=jnp.int32)
params = model.init(jax.random.PRNGKey(0), dummy_ids, deterministic=True)


In [53]:
# Plain Adam with a constant learning rate of 1e-3 (no schedule).
optimizer = optax.adam(1e-3)

In [54]:
# Bundle the apply function, parameters and optimizer into one TrainState
# that train_step updates functionally.
model_state = train_state.TrainState.create(
    apply_fn=model.apply,
    tx=optimizer,
    params=params,
)

In [107]:
def calculate_loss(params, batch, key, apply_fn):
    """Masked-language-modelling loss for one batch.

    Corrupts ``batch`` with ``mask`` and runs the model in training mode
    (dropout active). It then averages token-level softmax cross-entropy over
    the masked positions only.

    Args:
        params: variables passed straight to ``apply_fn``.
        batch: integer token-id array (presumably (batch, seq); confirm against callers).
        key: PRNG key. ``mask`` consumes it and returns the key used for dropout.
        apply_fn: the model's apply function.

    Returns:
        Scalar mean cross-entropy over masked positions (0 when nothing was masked).
    """
    input_ids, labels, dropout_key = mask(batch, key)
    logits = apply_fn(params, input_ids, deterministic=False, rngs={'dropout': dropout_key})

    vocab_dim = logits.shape[-1]
    weights = jnp.where(labels != -100, 1.0, 0.0).reshape(-1)
    # -100 is only a sentinel. Clamp it to a valid index so the per-token loss
    # is finite; those entries are zeroed by `weights` anyway.
    safe_labels = jnp.maximum(labels, 0).reshape(-1)

    per_token = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits.reshape(-1, vocab_dim),
        labels=safe_labels,
    )
    # The epsilon avoids 0/0 when no position in the batch was masked.
    return jnp.sum(per_token * weights) / (jnp.sum(weights) + 1e-8)

In [108]:
# (loss, d loss / d params): value_and_grad differentiates w.r.t. argument 0 by default.
grad_fn = jax.value_and_grad(calculate_loss)

In [109]:

@jax.jit
def train_step(state, batch, key):
    """Run one jitted optimisation step.

    Computes the masked-LM loss and its gradients, then applies them to ``state``.

    Returns:
        (new_state, loss)
    """
    loss_value, param_grads = grad_fn(state.params, batch, key, state.apply_fn)
    return state.apply_gradients(grads=param_grads), loss_value

In [110]:
# Training hyper-parameters.
epochs, batch_size = 10, 2  # NOTE(review): batch_size=2 means ~9.8k steps per epoch; confirm intended
dataset_size = len(train_enc)  # number of encoded training sentences

In [None]:
# Training loop: one shuffled pass over train_enc per epoch, dropping the final
# incomplete batch. Each step gets a fresh subkey for masking and dropout.
# Root key for training, kept distinct from the PRNGKey(0) used by model.init.
train_root_key = jax.random.PRNGKey(1)

for epoch in range(epochs):
    # Derive this epoch's key, then split it so the shuffle key and the per-step
    # key chain never reuse the same key material (the old loop reused `key`).
    perm_key, key = jax.random.split(jax.random.fold_in(train_root_key, epoch))
    indices = jax.random.permutation(perm_key, dataset_size)

    total_loss = 0.0
    batch_count = 0
    for b in range(dataset_size // batch_size):
        key, subkey = jax.random.split(key)
        start = b * batch_size
        batch = train_enc[indices[start:start + batch_size]]
        model_state, loss = train_step(model_state, batch, subkey)
        # Accumulate as a device value; it is converted to a Python float once per epoch.
        total_loss += loss
        batch_count += 1

    avg_loss = float(total_loss) / batch_count if batch_count > 0 else 0.0
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")