<a href="https://colab.research.google.com/github/Sravani-05/Assignment03/blob/main/297_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit
from flax import linen as nn
import optax

In [19]:
# Parameters
batch_size = 16
block_size = 32
learning_rate = 1e-3
max_iters = 1000
n_embd = 64
vocab_size = 256  # Assuming ASCII

rng_key = random.PRNGKey(0)

In [20]:
class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x[..., None]  # Adding an embedding dimension
        x = nn.Dense(n_embd)(x)
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(num_heads=2)(x)
        x = x.reshape((x.shape[0], x.shape[1], -1))  # Flattening the last dimensions
        x = nn.Dense(vocab_size)(x)
        return x

@jit
def softmax_cross_entropy(logits, targets):
    logits_reshaped = logits.reshape((-1, vocab_size))
    targets_reshaped = targets.reshape((-1,))
    logprobs = jax.nn.log_softmax(logits_reshaped)

    targets_one_hot = jax.nn.one_hot(targets_reshaped, vocab_size)

    # Element-wise multiplication and sum over the vocab_size dimension
    loss_values = -jnp.sum(targets_one_hot * logprobs, axis=-1)

    # Reshape loss values back to (batch_size, block_size)
    return loss_values.reshape((batch_size, block_size))

@jit
def compute_loss(params, x, y):
    logits = model.apply(params, x)
    loss_values = softmax_cross_entropy(logits, y)
    mean_loss = jnp.mean(loss_values)
    return mean_loss

@jit
def update(params, x, y, opt_state):
    opt_update = optimizer.update
    loss, grads = jax.value_and_grad(compute_loss)(params, x, y)  # removed `model` from the arguments
    updates, new_opt_state = opt_update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss


# Data (for demonstration purposes, use real data in practice)
data = jnp.array([i % vocab_size for i in range(10000)], dtype=jnp.int32)
def get_batch():
    idx = random.randint(rng_key, (batch_size,), 0, len(data) - block_size - 1)
    x = jnp.array([data[i:i+block_size] for i in idx])
    y = jnp.array([data[i+1:i+block_size+1] for i in idx])
    return x, y

# Training
model = Transformer()
params = model.init(rng_key, jnp.ones((batch_size, block_size)))
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

for iter in range(max_iters):
    x, y = get_batch()
    params, opt_state, loss = update(params, x, y, opt_state)
    if iter % 100 == 0:
        print(f"Iteration {iter}, Loss: {loss}")

# Additional utility function to convert a string to its ASCII representation
def string_to_ascii(input_str):
    return jnp.array([ord(c) for c in input_str], dtype=jnp.int32)

# Simple text generation
def generate_text(params, model, start_token=0, length=100):
    generated = [start_token]

    # Initialize a sequence of length `block_size` filled with the `start_token`
    current_token = jnp.array([start_token] * block_size).reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)  # Generate logits for the sequence
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Append the next_token to current_token sequence and use only the last `block_size` tokens
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return generated

# Initialize the model with a dummy input that matches the shape of our generation process
dummy_input = jnp.ones((1, block_size))
params_gen = model.init(rng_key, dummy_input)

# Update the params with trained weights
params_gen = params

def generate_text(params, model, start_string, length=100):
    start_tokens = string_to_ascii(start_string)
    generated = list(start_tokens)

    # If the initial tokens are fewer than block_size, pad them
    if len(start_tokens) < block_size:
        current_token = jnp.pad(start_tokens, (block_size - len(start_tokens), 0), mode='constant')
    else:
        current_token = start_tokens[-block_size:]  # Take the last `block_size` characters

    current_token = current_token.reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Use the most recent `block_size` tokens for the next step
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return "".join([chr(c) for c in generated])

print(generate_text(params_gen, model, start_string="once upon a time", length=100))


Iteration 0, Loss: 6.21189546585083
Iteration 100, Loss: 4.883039474487305
Iteration 200, Loss: 4.875761985778809
Iteration 300, Loss: 4.875408172607422
Iteration 400, Loss: 4.875296592712402
Iteration 500, Loss: 4.875243186950684
Iteration 600, Loss: 4.875209331512451
Iteration 700, Loss: 4.875184535980225
Iteration 800, Loss: 4.875168800354004
Iteration 900, Loss: 4.875140190124512
once upon a timeõõõõõõõõõõõõõõõö


In [21]:
# Parameters
batch_size = 16
block_size = 32
learning_rate = 1e-3
max_iters = 4000
n_embd = 64
vocab_size = 256  # Assuming ASCII

rng_key = random.PRNGKey(0)

In [23]:
class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x[..., None]  # Adding an embedding dimension
        x = nn.Dense(n_embd)(x)
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(num_heads=2)(x)
        x = x.reshape((x.shape[0], x.shape[1], -1))  # Flattening the last dimensions
        x = nn.Dense(vocab_size)(x)
        return x

@jit
def softmax_cross_entropy(logits, targets):
    logits_reshaped = logits.reshape((-1, vocab_size))
    targets_reshaped = targets.reshape((-1,))
    logprobs = jax.nn.log_softmax(logits_reshaped)

    targets_one_hot = jax.nn.one_hot(targets_reshaped, vocab_size)

    # Element-wise multiplication and sum over the vocab_size dimension
    loss_values = -jnp.sum(targets_one_hot * logprobs, axis=-1)

    # Reshape loss values back to (batch_size, block_size)
    return loss_values.reshape((batch_size, block_size))

@jit
def compute_loss(params, x, y):
    logits = model.apply(params, x)
    loss_values = softmax_cross_entropy(logits, y)
    mean_loss = jnp.mean(loss_values)
    return mean_loss

@jit
def update(params, x, y, opt_state):
    opt_update = optimizer.update
    loss, grads = jax.value_and_grad(compute_loss)(params, x, y)  # removed `model` from the arguments
    updates, new_opt_state = opt_update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss


# Data (for demonstration purposes, use real data in practice)
data = jnp.array([i % vocab_size for i in range(10000)], dtype=jnp.int32)
def get_batch():
    idx = random.randint(rng_key, (batch_size,), 0, len(data) - block_size - 1)
    x = jnp.array([data[i:i+block_size] for i in idx])
    y = jnp.array([data[i+1:i+block_size+1] for i in idx])
    return x, y

# Training
model = Transformer()
params = model.init(rng_key, jnp.ones((batch_size, block_size)))
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

for iter in range(max_iters):
    x, y = get_batch()
    params, opt_state, loss = update(params, x, y, opt_state)
    if iter % 100 == 0:
        print(f"Iteration {iter}, Loss: {loss}")

# Additional utility function to convert a string to its ASCII representation
def string_to_ascii(input_str):
    return jnp.array([ord(c) for c in input_str], dtype=jnp.int32)

# Simple text generation
def generate_text(params, model, start_token=0, length=100):
    generated = [start_token]

    # Initialize a sequence of length `block_size` filled with the `start_token`
    current_token = jnp.array([start_token] * block_size).reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)  # Generate logits for the sequence
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Append the next_token to current_token sequence and use only the last `block_size` tokens
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return generated

# Initialize the model with a dummy input that matches the shape of our generation process
dummy_input = jnp.ones((1, block_size))
params_gen = model.init(rng_key, dummy_input)

# Update the params with trained weights
params_gen = params

def generate_text(params, model, start_string, length=100):
    start_tokens = string_to_ascii(start_string)
    generated = list(start_tokens)

    # If the initial tokens are fewer than block_size, pad them
    if len(start_tokens) < block_size:
        current_token = jnp.pad(start_tokens, (block_size - len(start_tokens), 0), mode='constant')
    else:
        current_token = start_tokens[-block_size:]  # Take the last `block_size` characters

    current_token = current_token.reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Use the most recent `block_size` tokens for the next step
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return "".join([chr(c) for c in generated])

print(generate_text(params_gen, model, start_string="""O Captain! My Captain! our fearful trip is done;
The ship has weather'd every rack, the prize we sought is won;
The port is near, the bells I hear, the people all exulting,
While follow eyes the steady keel, the vessel grim and daring:

But O heart! heart! heart!
O the bleeding drops of red,
Where on the deck my Captain lies,
Fallen cold and dead.

O Captain! my Captain! rise up and hear the bells;
Rise up—for you the flag is flung—for you the bugle trills;
For you bouquets and ribbon'd wreaths—for you the shores a-crowding;
For you they call, the swaying mass, their eager faces turning;

O captain! dear father!
This arm beneath your head;
It is some dream that on the deck,
You've fallen cold and dead.

My Captain does not answer, his lips are pale and still;
My father does not feel my arm, he has no pulse nor will;
The ship is anchor'd safe and sound, its voyage closed and done""", length=100))


Iteration 0, Loss: 6.21189546585083
Iteration 100, Loss: 4.883039474487305
Iteration 200, Loss: 4.875761985778809
Iteration 300, Loss: 4.875408172607422
Iteration 400, Loss: 4.875296592712402
Iteration 500, Loss: 4.875243186950684
Iteration 600, Loss: 4.875209331512451
Iteration 700, Loss: 4.875184535980225
Iteration 800, Loss: 4.875168800354004
Iteration 900, Loss: 4.875140190124512
Iteration 1000, Loss: 4.875114440917969
Iteration 1100, Loss: 4.875144958496094
Iteration 1200, Loss: 4.875063419342041
Iteration 1300, Loss: 4.874737739562988
Iteration 1400, Loss: 4.875070571899414
Iteration 1500, Loss: 4.875022888183594
Iteration 1600, Loss: 4.874975204467773
Iteration 1700, Loss: 4.87492561340332
Iteration 1800, Loss: 4.874856948852539
Iteration 1900, Loss: 4.87477970123291
Iteration 2000, Loss: 4.874793529510498
Iteration 2100, Loss: 4.874538421630859
Iteration 2200, Loss: 4.87564754486084
Iteration 2300, Loss: 4.874076843261719
Iteration 2400, Loss: 4.873717308044434
Iteration 2500, 

In [24]:
import jax.numpy as jnp
from jax import random, grad, jit
from flax import linen as nn
import optax

# Parameters
batch_size = 16
block_size = 32
learning_rate = 1e-3
max_iters = 1000
n_embd = 64
vocab_size = 256  # Assuming ASCII

rng_key = random.PRNGKey(0)

class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x[..., None]  # Adding an embedding dimension
        x = nn.Dense(n_embd)(x)
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(num_heads=2)(x)
        x = x.reshape((x.shape[0], x.shape[1], -1))  # Flattening the last dimensions
        x = nn.Dense(vocab_size)(x)
        return x

@jit
def softmax_cross_entropy(logits, targets):
    logits_reshaped = logits.reshape((-1, vocab_size))
    targets_reshaped = targets.reshape((-1,))
    logprobs = jax.nn.log_softmax(logits_reshaped)

    targets_one_hot = jax.nn.one_hot(targets_reshaped, vocab_size)

    # Element-wise multiplication and sum over the vocab_size dimension
    loss_values = -jnp.sum(targets_one_hot * logprobs, axis=-1)

    # Reshape loss values back to (batch_size, block_size)
    return loss_values.reshape((batch_size, block_size))

@jit
def compute_loss(params, x, y):
    logits = model.apply(params, x)
    loss_values = softmax_cross_entropy(logits, y)
    mean_loss = jnp.mean(loss_values)
    return mean_loss

@jit
def update(params, x, y, opt_state):
    opt_update = optimizer.update
    loss, grads = jax.value_and_grad(compute_loss)(params, x, y)  # removed `model` from the arguments
    updates, new_opt_state = opt_update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss


# Data (for demonstration purposes, use real data in practice)
data = jnp.array([i % vocab_size for i in range(10000)], dtype=jnp.int32)
def get_batch():
    idx = random.randint(rng_key, (batch_size,), 0, len(data) - block_size - 1)
    x = jnp.array([data[i:i+block_size] for i in idx])
    y = jnp.array([data[i+1:i+block_size+1] for i in idx])
    return x, y

# Training
model = Transformer()
params = model.init(rng_key, jnp.ones((batch_size, block_size)))
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

for iter in range(max_iters):
    x, y = get_batch()
    params, opt_state, loss = update(params, x, y, opt_state)
    if iter % 100 == 0:
        print(f"Iteration {iter}, Loss: {loss}")

# Simple text generation
def generate_text(params, model, start_token=0, length=100):
    generated = [start_token]

    # Initialize a sequence of length `block_size` filled with the `start_token`
    current_token = jnp.array([start_token] * block_size).reshape(1, block_size)

    for _ in range(length):
        logits = model.apply(params, current_token)  # Generate logits for the sequence
        next_token = jnp.argmax(logits[0, -1])
        generated.append(int(next_token))

        # Append the next_token to current_token sequence and use only the last `block_size` tokens
        current_token = jnp.concatenate([current_token, next_token.reshape(1, 1)], axis=1)[:, -block_size:]

    return generated


print(generate_text(params, model, start_token=0, length=100))


Iteration 0, Loss: 6.21189546585083
Iteration 100, Loss: 4.883039474487305
Iteration 200, Loss: 4.875761985778809
Iteration 300, Loss: 4.875408172607422
Iteration 400, Loss: 4.875296592712402
Iteration 500, Loss: 4.875243186950684
Iteration 600, Loss: 4.875209331512451
Iteration 700, Loss: 4.875184535980225
Iteration 800, Loss: 4.875168800354004
Iteration 900, Loss: 4.875140190124512
[0, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 245, 246, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137]
