In [1]:
pip install jax jaxlib flax optax


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import numpy as np

# Define the Neural Network

In [3]:
class SimpleNN(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_size)(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_size)(x)
        return x

# Create a Training State

In [4]:
def create_train_state(rng, learning_rate, hidden_size, output_size):
    model = SimpleNN(hidden_size=hidden_size, output_size=output_size)
    params = model.init(rng, jnp.ones([1, 784]))['params']  # Example input shape
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )

# Define the Loss Function

In [5]:
def mse_loss(params, state, x, y):
    preds = state.apply_fn({'params': params}, x)
    return jnp.mean((preds - y) ** 2)

# Define the Training Step

In [6]:
@jax.jit
def train_step(state, x, y):
    grad_fn = jax.grad(mse_loss)
    grads = grad_fn(state.params, state, x, y)
    state = state.apply_gradients(grads=grads)
    return state

# Train the model

In [7]:
def train_model(state, x_train, y_train, num_epochs):
    for epoch in range(num_epochs):
        state = train_step(state, x_train, y_train)
        if epoch % 10 == 0:
            loss = mse_loss(state.params, state, x_train, y_train)
            print(f'Epoch {epoch}, Loss: {loss}')
    return state

# For just showing we'll use synthetic data

In [8]:
def generate_data(num_samples, input_size, output_size):
    rng = jax.random.PRNGKey(0)
    x = jax.random.normal(rng, (num_samples, input_size))
    y = jnp.dot(x, jnp.ones((input_size, output_size))) + 0.1 * jax.random.normal(rng, (num_samples, output_size))
    return x, y

#  Putting everything together

In [9]:
def main():
    rng = jax.random.PRNGKey(0)
    input_size = 784
    hidden_size = 128
    output_size = 10
    num_samples = 1000
    num_epochs = 100
    learning_rate = 0.01

    x_train, y_train = generate_data(num_samples, input_size, output_size)
    state = create_train_state(rng, learning_rate, hidden_size, output_size)
    state = train_model(state, x_train, y_train, num_epochs)

if __name__ == "__main__":
    main()

Epoch 0, Loss: 741.6810302734375
Epoch 10, Loss: 104.13189697265625
Epoch 20, Loss: 24.74224853515625
Epoch 30, Loss: 11.690950393676758
Epoch 40, Loss: 4.067071914672852
Epoch 50, Loss: 1.7465698719024658
Epoch 60, Loss: 0.7526335716247559
Epoch 70, Loss: 0.39823049306869507
Epoch 80, Loss: 0.22409191727638245
Epoch 90, Loss: 0.14787310361862183


# Now adding 
Multiple hidden layers.
Dropout for regularization.
Batch normalization.
Skip connections (residual blocks).
A custom learning rate scheduler.


In [10]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import numpy as np

#residual block with dropout and batch normalization
class ResidualBlock(nn.Module):
    features: int
    dropout_rate: float

    @nn.compact
    def __call__(self, x, training: bool):
        residual = x
        x = nn.Dense(self.features)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x)
        x = nn.Dense(self.features)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x += residual  
        x = nn.relu(x)
        return x

class TahirsComplexModel(nn.Module):
    hidden_size: int
    output_size: int
    dropout_rate: float
    num_blocks: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.hidden_size)(x)
        x = nn.relu(x)

        #multiple residual blocks
        for _ in range(self.num_blocks):
            x = ResidualBlock(features=self.hidden_size, dropout_rate=self.dropout_rate)(x, training)

        x = nn.Dense(self.output_size)(x)
        return x

#learning rate scheduler
def create_learning_rate_scheduler(base_learning_rate, warmup_epochs, total_epochs):
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=base_learning_rate, transition_steps=warmup_epochs
    )
    cosine_fn = optax.cosine_decay_schedule(
        init_value=base_learning_rate, decay_steps=total_epochs - warmup_epochs
    )
    return optax.join_schedules(
        schedules=[warmup_fn, cosine_fn], boundaries=[warmup_epochs]
    )

#Create a training state
def create_train_state(rng, learning_rate, hidden_size, output_size, dropout_rate, num_blocks):
    model = TahirsComplexModel(hidden_size=hidden_size, output_size=output_size, dropout_rate=dropout_rate, num_blocks=num_blocks)
    rng1, rng2 = jax.random.split(rng)
    variables = model.init({'params': rng1, 'dropout': rng2}, jnp.ones([1, 784]), training=True)
    params = variables['params']
    batch_stats = variables['batch_stats'] 
    lr_scheduler = create_learning_rate_scheduler(learning_rate, warmup_epochs=10, total_epochs=100)
    tx = optax.adamw(lr_scheduler, weight_decay=1e-4)  
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    ), batch_stats

#Define the loss function with L2 regularization
def cross_entropy_loss(params, batch_stats, state, x, y, rng):
    logits, new_batch_stats = state.apply_fn(
        {'params': params, 'batch_stats': batch_stats},
        x,
        training=True,
        rngs={'dropout': rng},
        mutable=['batch_stats']  
    )
    one_hot_y = jax.nn.one_hot(y, num_classes=10)
    loss = -jnp.mean(jnp.sum(one_hot_y * jax.nn.log_softmax(logits), axis=-1))
    return loss, new_batch_stats

#Define the training step with gradient clipping
@jax.jit
def train_step(state, batch_stats, x, y, rng):
    grad_fn = jax.value_and_grad(cross_entropy_loss, has_aux=True)
    (loss, new_batch_stats), grads = grad_fn(state.params, batch_stats, state, x, y, rng)
    grads = jax.tree_map(lambda g: jnp.clip(g, -1.0, 1.0), grads)
    state = state.apply_gradients(grads=grads)
    return state, new_batch_stats, loss

def train_model(state, batch_stats, x_train, y_train, num_epochs, rng):
    best_loss = float('inf')
    patience = 5  
    wait = 0

    for epoch in range(num_epochs):
        rng, dropout_rng = jax.random.split(rng)
        state, batch_stats, loss = train_step(state, batch_stats, x_train, y_train, dropout_rng)
        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {loss}')

        if loss < best_loss:
            best_loss = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print(f'Early stopping at epoch {epoch}')
                break
    return state, batch_stats

def generate_data(num_samples, input_size, output_size):
    rng = jax.random.PRNGKey(0)
    x = jax.random.normal(rng, (num_samples, input_size))
    y = jax.random.randint(rng, (num_samples,), 0, output_size)
    return x, y

def main():
    rng = jax.random.PRNGKey(0)
    input_size = 784
    hidden_size = 256
    output_size = 10
    num_samples = 1000
    num_epochs = 100
    learning_rate = 0.001
    dropout_rate = 0.5
    num_blocks = 5

    x_train, y_train = generate_data(num_samples, input_size, output_size)
    state, batch_stats = create_train_state(rng, learning_rate, hidden_size, output_size, dropout_rate, num_blocks)
    state, batch_stats = train_model(state, batch_stats, x_train, y_train, num_epochs, rng)

if __name__ == "__main__":
    main()

  grads = jax.tree_map(lambda g: jnp.clip(g, -1.0, 1.0), grads)


Epoch 0, Loss: 4.108287811279297
Epoch 10, Loss: 2.6931939125061035
Epoch 20, Loss: 1.9992084503173828
Epoch 30, Loss: 1.3193739652633667
Epoch 40, Loss: 0.7016776204109192
Epoch 50, Loss: 0.3615359961986542
Epoch 60, Loss: 0.22701393067836761
Epoch 70, Loss: 0.13619503378868103
Epoch 80, Loss: 0.11714537441730499
Early stopping at epoch 86


# Now we have warm up and now we will implement and train a miniGPT-like model using JAX and Flax

In [None]:
!pip install jax flax optax

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import numpy as np
from typing import Any, Optional

In [None]:
import requests

url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

response = requests.get(url)
with open("shakespeare.txt", "w") as f:
    f.write(response.text)

print("Downloaded shakespeare.txt")

In [None]:
text = open("shakespeare.txt", "r").read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

data = jnp.array([char_to_idx[ch] for ch in text])

seq_length = 128
inputs = [data[i:i+seq_length] for i in range(len(data) - seq_length)]
targets = [data[i+1:i+seq_length+1] for i in range(len(data) - seq_length)]
inputs = jnp.stack(inputs)
targets = jnp.stack(targets)

In [None]:
class TransformerBlock(nn.Module):
    embed_dim: int
    num_heads: int
    dropout_rate: float

    @nn.compact
    def __call__(self, x, training: bool):
        # Multi-head self-attention
        attn_output = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.embed_dim,
            dropout_rate=self.dropout_rate,
            deterministic=not training,
        )(x, x)

        # Add & Norm
        x = x + nn.Dropout(rate=self.dropout_rate, deterministic=not training)(attn_output)
        x = nn.LayerNorm()(x)

        # Feedforward network
        ff_output = nn.Dense(self.embed_dim * 4)(x)
        ff_output = nn.relu(ff_output)
        ff_output = nn.Dense(self.embed_dim)(ff_output)

        # Add & Norm
        x = x + nn.Dropout(rate=self.dropout_rate, deterministic=not training)(ff_output)
        x = nn.LayerNorm()(x)

        return x

class MiniGPT(nn.Module):
    vocab_size: int
    embed_dim: int
    num_heads: int
    num_layers: int
    dropout_rate: float

    @nn.compact
    def __call__(self, x, training: bool):
        # Token embeddings
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)

        # Positional embeddings
        positions = jnp.arange(x.shape[1])[None, :]
        pos_embeddings = nn.Embed(seq_length, self.embed_dim)(positions)
        x = x + pos_embeddings

        # Transformer blocks
        for _ in range(self.num_layers):
            x = TransformerBlock(
                embed_dim=self.embed_dim,
                num_heads=self.num_heads,
                dropout_rate=self.dropout_rate,
            )(x, training)

        # Output logits
        logits = nn.Dense(self.vocab_size)(x)
        return logits

In [None]:
# Create the model
model = MiniGPT(
    vocab_size=vocab_size,
    embed_dim=128,
    num_heads=8,
    num_layers=6,
    dropout_rate=0.1,
)

# Initialize the model
rng = jax.random.PRNGKey(0)
params = model.init(rng, jnp.ones((1, seq_length), dtype=jnp.int32), training=False)['params']

# Define the loss function
def cross_entropy_loss(params, inputs, targets):
    logits = model.apply({'params': params}, inputs, training=True)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
    return loss

# Create the optimizer
learning_rate = 0.001
tx = optax.adamw(learning_rate)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)

# Training step
@jax.jit
def train_step(state, inputs, targets):
    grad_fn = jax.grad(cross_entropy_loss)
    grads = grad_fn(state.params, inputs, targets)
    state = state.apply_gradients(grads=grads)
    return state

# Training loop
batch_size = 64
for epoch in range(10):  # Train for 10 epochs
    for i in range(0, len(inputs), batch_size):
        batch_inputs = inputs[i:i+batch_size]
        batch_targets = targets[i:i+batch_size]
        state = train_step(state, batch_inputs, batch_targets)
    print(f"Epoch {epoch + 1}, Loss: {cross_entropy_loss(state.params, batch_inputs, batch_targets)}")

In [None]:
def generate_text(state, start_string, num_chars=100):
    input_ids = jnp.array([char_to_idx[ch] for ch in start_string])
    for _ in range(num_chars):
        logits = model.apply({'params': state.params}, input_ids[None, :], training=False)
        next_id = jnp.argmax(logits[0, -1])
        input_ids = jnp.append(input_ids, next_id)
    return ''.join([idx_to_char[i] for i in input_ids])

# Generate text
start_string = "ROMEO:"
generated_text = generate_text(state, start_string, num_chars=200)
print(generated_text)