<a href="https://colab.research.google.com/github/Dhanasree-Rajamani/SpecialTopics_DeepLearning/blob/main/Assignment%203/Text_generation_JAX_297.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Importing Required Libraries


In [None]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit
from flax import linen as nn
import optax

Hyperparameters

Defining hyperparameters like batch size, block size, learning rate, etc., that will govern the training and model architecture.

In [None]:
# Parameters
batch_size = 16
block_size = 32
learning_rate = 1e-3
max_iters = 1000
n_embd = 64
vocab_size = 256  # Assuming ASCII

rng_key = random.PRNGKey(0)

Data Tokenization

Creating a vocabulary by finding unique characters in the dataset. Mapping characters to unique integers for encoding and the reverse for decoding.

Encoding the Data

Converting the entire text into a sequence of integers.

Batch Data Generator

Function get_batch randomly samples batches of data for training the model.

In [None]:
class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x[..., None]  # Adding an embedding dimension
        x = nn.Dense(n_embd)(x)
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(num_heads=2)(x)
        x = x.reshape((x.shape[0], x.shape[1], -1))  # Flattening the last dimensions
        x = nn.Dense(vocab_size)(x)
        return x

@jit
def softmax_cross_entropy(logits, targets):
    logits_reshaped = logits.reshape((-1, vocab_size))
    targets_reshaped = targets.reshape((-1,))
    logprobs = jax.nn.log_softmax(logits_reshaped)

    targets_one_hot = jax.nn.one_hot(targets_reshaped, vocab_size)

    # Element-wise multiplication and sum over the vocab_size dimension
    loss_values = -jnp.sum(targets_one_hot * logprobs, axis=-1)

    # Reshape loss values back to (batch_size, block_size)
    return loss_values.reshape((batch_size, block_size))

@jit
def compute_loss(params, x, y):
    logits = model.apply(params, x)
    loss_values = softmax_cross_entropy(logits, y)
    mean_loss = jnp.mean(loss_values)
    return mean_loss

@jit
def update(params, x, y, opt_state):
    opt_update = optimizer.update
    loss, grads = jax.value_and_grad(compute_loss)(params, x, y)  # removed `model` from the arguments
    updates, new_opt_state = opt_update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss


# Data (for demonstration purposes, use real data in practice)
data = jnp.array([i % vocab_size for i in range(10000)], dtype=jnp.int32)
def get_batch():
    idx = random.randint(rng_key, (batch_size,), 0, len(data) - block_size - 1)
    x = jnp.array([data[i:i+block_size] for i in idx])
    y = jnp.array([data[i+1:i+block_size+1] for i in idx])
    return x, y

# Training
model = Transformer()
params = model.init(rng_key, jnp.ones((batch_size, block_size)))
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

for iter in range(max_iters):
    x, y = get_batch()
    params, opt_state, loss = update(params, x, y, opt_state)
    if iter % 100 == 0:
        print(f"Iteration {iter}, Loss: {loss}")

# Additional utility function to convert a string to its ASCII representation
def string_to_ascii(input_str):
    return jnp.array([ord(c) for c in input_str], dtype=jnp.int32)

# Simple text generation
def generate_text(params, model, start_token=0, length=100):
    generated = [start_token]

    # Initialize a sequence of length `block_size` filled with the `start_token`
    current_token = jnp.array([start_token] * block_size).reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)  # Generate logits for the sequence
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Append the next_token to current_token sequence and use only the last `block_size` tokens
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return generated

# Initialize the model with a dummy input that matches the shape of our generation process
dummy_input = jnp.ones((1, block_size))
params_gen = model.init(rng_key, dummy_input)

# Update the params with trained weights
params_gen = params

def generate_text(params, model, start_string, length=100):
    start_tokens = string_to_ascii(start_string)
    generated = list(start_tokens)

    # If the initial tokens are fewer than block_size, pad them
    if len(start_tokens) < block_size:
        current_token = jnp.pad(start_tokens, (block_size - len(start_tokens), 0), mode='constant')
    else:
        current_token = start_tokens[-block_size:]  # Take the last `block_size` characters

    current_token = current_token.reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Use the most recent `block_size` tokens for the next step
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return "".join([chr(c) for c in generated])

print(generate_text(params_gen, model, start_string="once upon a time", length=100))


Iteration 0, Loss: 6.21189546585083
Iteration 100, Loss: 4.883039474487305
Iteration 200, Loss: 4.875761985778809
Iteration 300, Loss: 4.875408172607422
Iteration 400, Loss: 4.875296592712402
Iteration 500, Loss: 4.875243186950684
Iteration 600, Loss: 4.875209808349609
Iteration 700, Loss: 4.875184535980225
Iteration 800, Loss: 4.875165939331055
Iteration 900, Loss: 4.875140190124512
once upon a timeõõõõõõõõõõõõõõõ>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>


In [None]:
# Parameters
batch_size = 16
block_size = 32
learning_rate = 1e-3
max_iters = 4000
n_embd = 64
vocab_size = 256  # Assuming ASCII

rng_key = random.PRNGKey(0)

In [None]:
class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x[..., None]  # Adding an embedding dimension
        x = nn.Dense(n_embd)(x)
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(num_heads=2)(x)
        x = x.reshape((x.shape[0], x.shape[1], -1))  # Flattening the last dimensions
        x = nn.Dense(vocab_size)(x)
        return x

@jit
def softmax_cross_entropy(logits, targets):
    logits_reshaped = logits.reshape((-1, vocab_size))
    targets_reshaped = targets.reshape((-1,))
    logprobs = jax.nn.log_softmax(logits_reshaped)

    targets_one_hot = jax.nn.one_hot(targets_reshaped, vocab_size)

    # Element-wise multiplication and sum over the vocab_size dimension
    loss_values = -jnp.sum(targets_one_hot * logprobs, axis=-1)

    # Reshape loss values back to (batch_size, block_size)
    return loss_values.reshape((batch_size, block_size))

@jit
def compute_loss(params, x, y):
    logits = model.apply(params, x)
    loss_values = softmax_cross_entropy(logits, y)
    mean_loss = jnp.mean(loss_values)
    return mean_loss

@jit
def update(params, x, y, opt_state):
    opt_update = optimizer.update
    loss, grads = jax.value_and_grad(compute_loss)(params, x, y)  # removed `model` from the arguments
    updates, new_opt_state = opt_update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss


# Data (for demonstration purposes, use real data in practice)
data = jnp.array([i % vocab_size for i in range(10000)], dtype=jnp.int32)
def get_batch():
    idx = random.randint(rng_key, (batch_size,), 0, len(data) - block_size - 1)
    x = jnp.array([data[i:i+block_size] for i in idx])
    y = jnp.array([data[i+1:i+block_size+1] for i in idx])
    return x, y

# Training
model = Transformer()
params = model.init(rng_key, jnp.ones((batch_size, block_size)))
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

for iter in range(max_iters):
    x, y = get_batch()
    params, opt_state, loss = update(params, x, y, opt_state)
    if iter % 100 == 0:
        print(f"Iteration {iter}, Loss: {loss}")

# Additional utility function to convert a string to its ASCII representation
def string_to_ascii(input_str):
    return jnp.array([ord(c) for c in input_str], dtype=jnp.int32)

# Simple text generation
def generate_text(params, model, start_token=0, length=100):
    generated = [start_token]

    # Initialize a sequence of length `block_size` filled with the `start_token`
    current_token = jnp.array([start_token] * block_size).reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)  # Generate logits for the sequence
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Append the next_token to current_token sequence and use only the last `block_size` tokens
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return generated

# Initialize the model with a dummy input that matches the shape of our generation process
dummy_input = jnp.ones((1, block_size))
params_gen = model.init(rng_key, dummy_input)

# Update the params with trained weights
params_gen = params

def generate_text(params, model, start_string, length=100):
    start_tokens = string_to_ascii(start_string)
    generated = list(start_tokens)

    # If the initial tokens are fewer than block_size, pad them
    if len(start_tokens) < block_size:
        current_token = jnp.pad(start_tokens, (block_size - len(start_tokens), 0), mode='constant')
    else:
        current_token = start_tokens[-block_size:]  # Take the last `block_size` characters

    current_token = current_token.reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Use the most recent `block_size` tokens for the next step
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return "".join([chr(c) for c in generated])

print(generate_text(params_gen, model, start_string="""The farmhouse lingers, though averse to square
With the new city street it has to wear
A number in. But what about the brook
That held the house as in an elbow-crook?
I ask as one who knew the brook, its strength
And impulse, having dipped a finger length
And made it leap my knuckle, having tossed
A flower to try its currents where they crossed.
The meadow grass could be cemented down
From growing under pavements of a town;
The apple trees be sent to hearth-stone flame.
Is water wood to serve a brook the same?
How else dispose of an immortal force
No longer needed? Staunch it at its source
With cinder loads dumped down? The brook was thrown
Deep in a sewer dungeon under stone
In fetid darkness still to live and run â€”
And all for nothing it had ever done
Except forget to go in fear perhaps.
No one would know except for ancient maps
That such a brook ran water. But I wonder
If from its being kept forever under,
The thoughts may not have risen that so keep
This new-built city from both work and sleep.""", length=100))


Iteration 0, Loss: 6.21189546585083
Iteration 100, Loss: 4.883039474487305
Iteration 200, Loss: 4.875761985778809
Iteration 300, Loss: 4.875408172607422
Iteration 400, Loss: 4.875296592712402
Iteration 500, Loss: 4.875243186950684
Iteration 600, Loss: 4.875209808349609
Iteration 700, Loss: 4.875184535980225
Iteration 800, Loss: 4.875165939331055
Iteration 900, Loss: 4.875140190124512
Iteration 1000, Loss: 4.8751139640808105
Iteration 1100, Loss: 4.87531042098999
Iteration 1200, Loss: 4.875078201293945
Iteration 1300, Loss: 4.875041961669922
Iteration 1400, Loss: 4.875000953674316
Iteration 1500, Loss: 4.874953746795654
Iteration 1600, Loss: 4.874897480010986
Iteration 1700, Loss: 4.874874114990234
Iteration 1800, Loss: 4.874741554260254
Iteration 1900, Loss: 4.874608993530273
Iteration 2000, Loss: 4.874403476715088
Iteration 2100, Loss: 4.874251365661621
Iteration 2200, Loss: 4.873983383178711
Iteration 2300, Loss: 4.873186111450195
Iteration 2400, Loss: 4.879136085510254
Iteration 250

In [None]:
import jax.numpy as jnp
from jax import random, grad, jit
from flax import linen as nn
import optax

# Parameters
batch_size = 16
block_size = 32
learning_rate = 1e-3
max_iters = 1000
n_embd = 64
vocab_size = 256  # Assuming ASCII

rng_key = random.PRNGKey(0)

class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x[..., None]  # Adding an embedding dimension
        x = nn.Dense(n_embd)(x)
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(num_heads=2)(x)
        x = x.reshape((x.shape[0], x.shape[1], -1))  # Flattening the last dimensions
        x = nn.Dense(vocab_size)(x)
        return x

@jit
def softmax_cross_entropy(logits, targets):
    logits_reshaped = logits.reshape((-1, vocab_size))
    targets_reshaped = targets.reshape((-1,))
    logprobs = jax.nn.log_softmax(logits_reshaped)

    targets_one_hot = jax.nn.one_hot(targets_reshaped, vocab_size)

    # Element-wise multiplication and sum over the vocab_size dimension
    loss_values = -jnp.sum(targets_one_hot * logprobs, axis=-1)

    # Reshape loss values back to (batch_size, block_size)
    return loss_values.reshape((batch_size, block_size))

@jit
def compute_loss(params, x, y):
    logits = model.apply(params, x)
    loss_values = softmax_cross_entropy(logits, y)
    mean_loss = jnp.mean(loss_values)
    return mean_loss

@jit
def update(params, x, y, opt_state):
    opt_update = optimizer.update
    loss, grads = jax.value_and_grad(compute_loss)(params, x, y)  # removed `model` from the arguments
    updates, new_opt_state = opt_update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss


# Data (for demonstration purposes, use real data in practice)
data = jnp.array([i % vocab_size for i in range(10000)], dtype=jnp.int32)
def get_batch():
    idx = random.randint(rng_key, (batch_size,), 0, len(data) - block_size - 1)
    x = jnp.array([data[i:i+block_size] for i in idx])
    y = jnp.array([data[i+1:i+block_size+1] for i in idx])
    return x, y

# Training
model = Transformer()
params = model.init(rng_key, jnp.ones((batch_size, block_size)))
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

for iter in range(max_iters):
    x, y = get_batch()
    params, opt_state, loss = update(params, x, y, opt_state)
    if iter % 100 == 0:
        print(f"Iteration {iter}, Loss: {loss}")

# Simple text generation
def generate_text(params, model, start_token=0, length=100):
    generated = [start_token]

    # Initialize a sequence of length `block_size` filled with the `start_token`
    current_token = jnp.array([start_token] * block_size).reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)  # Generate logits for the sequence
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Append the next_token to current_token sequence and use only the last `block_size` tokens
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return generated


print(generate_text(params, model, start_token=0, length=100))


Iteration 0, Loss: 6.21189546585083
Iteration 100, Loss: 4.883039474487305
Iteration 200, Loss: 4.875761985778809
Iteration 300, Loss: 4.875408172607422
Iteration 400, Loss: 4.875296592712402
Iteration 500, Loss: 4.875243186950684
Iteration 600, Loss: 4.875209808349609
Iteration 700, Loss: 4.875184535980225
Iteration 800, Loss: 4.875165939331055
Iteration 900, Loss: 4.875140190124512
[0, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 5, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62]
