<a href="https://colab.research.google.com/github/Utkarsh-Aggarwal/local-repo/blob/main/text_genrator_using_jax_and_flax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
!pip install jax jaxlib flax optax numpy
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
import optax
import jax.random as random
from functools import partial
import time



In [15]:
corpus = "hello world! hello jax! hello flax! "

# Create a vocabulary (unique characters) and mappings.
vocab = sorted(list(set(corpus)))
vocab_size = len(vocab)
char2idx = {ch: i for i, ch in enumerate(vocab)}
idx2char = {i: ch for i, ch in enumerate(vocab)}

# Convert the entire corpus into a sequence of integer indices.
data = np.array([char2idx[c] for c in corpus], dtype=np.int32)

In [16]:
embed_dim = 32        # Embedding dimension for tokens.
num_heads = 2         # Number of attention heads.
num_layers = 2        # Number of stacked transformer blocks.
ff_dim = 64           # Hidden dimension in the feed-forward network.
block_size = 16       # Maximum sequence length (context window).
dropout_rate = 0.1    # Dropout rate.
learning_rate = 1e-3  # Learning rate for optimizer.
num_epochs = 1000     # Total training epochs.
batch_size = 16       # Training batch size.


In [17]:
def positional_encoding(seq_len, d_model):
    """
    Computes a fixed sinusoidal positional encoding.

    Args:
      seq_len: Length of the sequence.
      d_model: Dimension of the embeddings.

    Returns:
      A JAX array of shape (seq_len, d_model) containing the positional encodings.
    """
    pos = np.arange(seq_len)[:, np.newaxis]  # Shape (seq_len, 1)
    i = np.arange(d_model)[np.newaxis, :]      # Shape (1, d_model)
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    pos_encoding = pos * angle_rates
    # Apply sin to even indices and cos to odd indices.
    pos_encoding[:, 0::2] = np.sin(pos_encoding[:, 0::2])
    pos_encoding[:, 1::2] = np.cos(pos_encoding[:, 1::2])
    return jnp.array(pos_encoding)

In [18]:
class TransformerBlock(nn.Module):
    """A single transformer block consisting of self-attention and feed-forward layers."""
    embed_dim: int
    num_heads: int
    ff_dim: int
    dropout_rate: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        # Multi-head self-attention sub-layer
        residual = x
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(
            num_heads=self.num_heads,
            qkv_features=self.embed_dim,
            dropout_rate=self.dropout_rate,
            deterministic=deterministic
        )(x)
        x = x + residual  # Residual connection

        # Feed-forward network sub-layer
        residual = x
        x = nn.LayerNorm()(x)
        x = nn.Dense(self.ff_dim)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = nn.Dense(self.embed_dim)(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = x + residual  # Residual connection
        return x

In [20]:

class TransformerLM(nn.Module):
    """
    Transformer-based language model.

    Args:
      vocab_size: Number of tokens in the vocabulary.
      embed_dim: Dimension of token embeddings.
      num_heads: Number of attention heads.
      num_layers: Number of transformer blocks.
      ff_dim: Dimension of the feed-forward network.
      block_size: Maximum length of input sequence.
      dropout_rate: Dropout rate.
    """
    vocab_size: int
    embed_dim: int
    num_heads: int
    num_layers: int
    ff_dim: int
    block_size: int
    dropout_rate: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        # x shape: (batch, sequence_length)
        # Token embedding layer
        x = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim)(x)

        # Instead of using a fixed positional encoding of shape (block_size, embed_dim),
        # we compute positional encoding based on the actual sequence length.
        seq_len = x.shape[1]
        pos_enc = positional_encoding(seq_len, self.embed_dim)
        x = x + pos_enc

        # Apply a stack of transformer blocks.
        for _ in range(self.num_layers):
            x = TransformerBlock(
                embed_dim=self.embed_dim,
                num_heads=self.num_heads,
                ff_dim=self.ff_dim,
                dropout_rate=self.dropout_rate
            )(x, deterministic=deterministic)

        # Final layer normalization.
        x = nn.LayerNorm()(x)

        # Project the outputs to logits for each vocabulary token.
        logits = nn.Dense(self.vocab_size)(x)
        return logits

# Instantiate the model.
model = TransformerLM(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    ff_dim=ff_dim,
    block_size=block_size,
    dropout_rate=dropout_rate
)

#######################################
# INITIALIZATION
#######################################
# Create a random key for initialization and training.
rng = random.PRNGKey(0)
# Dummy input for shape inference (batch_size x block_size).
dummy_input = jnp.ones((batch_size, block_size), dtype=jnp.int32)
params = model.init(rng, dummy_input)

In [19]:
def cross_entropy_loss(logits, targets):
    """
    Computes the cross-entropy loss between predicted logits and target tokens.

    Args:
      logits: Logits from the model of shape (batch, seq_length, vocab_size).
      targets: Ground truth token indices of shape (batch, seq_length).

    Returns:
      The mean cross-entropy loss.
    """
    one_hot_targets = jax.nn.one_hot(targets, logits.shape[-1])
    loss = optax.softmax_cross_entropy(logits, one_hot_targets)
    return loss.mean()

# Set up the Adam optimizer using optax.
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)


In [21]:
@jax.jit
def train_step(params, opt_state, batch, rng):
    """
    Performs one training step: computes loss, gradients, and updates parameters.

    Args:
      params: Model parameters.
      opt_state: Optimizer state.
      batch: Batch of token sequences (shape: (batch, block_size)).
      rng: Random key for dropout.

    Returns:
      Updated parameters, updated optimizer state, and the loss value.
    """
    def loss_fn(params):
        # Forward pass: obtain logits for the batch.
        logits = model.apply(params, batch, deterministic=False, rngs={'dropout': rng})
        # Use tokens 0 to block_size-1 as input and predict tokens 1 to block_size.
        loss = cross_entropy_loss(logits[:, :-1], batch[:, 1:])
        return loss
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

#######################################
# TRAINING LOOP
#######################################
def get_batch(data, batch_size, block_size):
    """
    Creates a batch of input sequences from the data.

    Args:
      data: Array of token indices.
      batch_size: Number of sequences in a batch.
      block_size: Length of each sequence.

    Returns:
      A JAX array of shape (batch_size, block_size).
    """
    n = len(data) - block_size
    idx = np.random.randint(0, n, (batch_size,))
    batch = np.stack([data[i:i+block_size] for i in idx])
    return jnp.array(batch)

print("Starting training...")
for epoch in range(num_epochs):
    batch = get_batch(data, batch_size, block_size)
    rng, step_rng = random.split(rng)
    params, opt_state, loss = train_step(params, opt_state, batch, step_rng)
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")


Starting training...
Epoch 0, Loss: 3.0474
Epoch 100, Loss: 0.6668
Epoch 200, Loss: 0.2297
Epoch 300, Loss: 0.1495
Epoch 400, Loss: 0.0968
Epoch 500, Loss: 0.0470
Epoch 600, Loss: 0.0805
Epoch 700, Loss: 0.0156
Epoch 800, Loss: 0.0281
Epoch 900, Loss: 0.0144


In [22]:
def generate_text(params, seed_text, length, rng):
    """
    Generates text using the trained model.

    Args:
      params: Trained model parameters.
      seed_text: Initial text to seed generation.
      length: Number of tokens to generate.
      rng: Random key.

    Returns:
      A string containing the generated text.
    """
    # Convert the seed text to token indices.
    input_seq = jnp.array([char2idx[c] for c in seed_text], dtype=jnp.int32)[None, :]
    generated = list(seed_text)
    for _ in range(length):
        # If sequence is longer than block_size, use only the last block_size tokens.
        input_seq_cond = input_seq[:, -block_size:]
        logits = model.apply(params, input_seq_cond, deterministic=True)
        # Get logits for the last token.
        logits = logits[:, -1, :]
        # Use greedy sampling: pick the token with the highest logit.
        next_token = jnp.argmax(logits, axis=-1)
        next_token = int(next_token[0])
        generated.append(idx2char[next_token])
        # Append the predicted token to the sequence.
        input_seq = jnp.concatenate([input_seq, jnp.array([[next_token]], dtype=jnp.int32)], axis=1)
    return "".join(generated)

In [23]:
seed = "hello"
generated_text = generate_text(params, seed, length=50, rng=rng)
print("Generated text:")
print(generated_text)

Generated text:
hello jax! helllo flax! hello jax! helllo flax! hello j
