In [None]:
import urllib.request

import jax
import optax
from flax import linen as nn
from flax import serialization
from flax.training import train_state
from jax import grad, jit, vmap
from jax import numpy as np, random

# Download and preprocess the Shakespeare dataset.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# JAX has no file/URL I/O API, so fetch the corpus with the standard library.
with urllib.request.urlopen(url, timeout=60) as response:
    raw_text = response.read().decode("utf-8")

# Character-level vocabulary and the two lookup tables between chars and ids.
vocab = sorted(set(raw_text))
char_to_idx = {char: idx for idx, char in enumerate(vocab)}
idx_to_char = dict(enumerate(vocab))

# Integer-encode the corpus: `shakespeare_text` is a 1-D int32 array of char ids.
shakespeare_text = np.array([char_to_idx[char] for char in raw_text], dtype=np.int32)

# Define the Transformer model
def _sinusoidal_positions(length, dim):
    """Return a fixed ``(length, dim)`` sinusoidal position encoding.

    Sine features fill the first ``ceil(dim / 2)`` columns and cosine features
    the rest, so odd and even ``dim`` are both supported.
    """
    half = (dim + 1) // 2
    positions = np.arange(length, dtype=np.float32)[:, None]
    freqs = np.exp(-np.log(10000.0) * np.arange(half, dtype=np.float32) / half)
    angles = positions * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)[:, :dim]


class _DecoderBlock(nn.Module):
    """Pre-LayerNorm decoder block: masked self-attention then an MLP, both residual."""

    hidden_dim: int
    num_heads: int
    mlp_ratio: int = 4

    @nn.compact
    def __call__(self, x, mask):
        attn_in = nn.LayerNorm()(x)
        x = x + nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads, qkv_features=self.hidden_dim
        )(attn_in, attn_in, mask=mask)
        mlp_in = nn.LayerNorm()(x)
        hidden = nn.gelu(nn.Dense(self.mlp_ratio * self.hidden_dim)(mlp_in))
        return x + nn.Dense(self.hidden_dim)(hidden)


class Transformer(nn.Module):
    """Decoder-only (causal) Transformer language model over integer token ids.

    Attributes:
        hidden_dim: width of the embeddings / residual stream; must be
            divisible by ``num_heads``.
        num_heads: attention heads per layer.
        num_layers: number of stacked decoder blocks.
        vocab_size: number of distinct token ids (embedding rows and logits).
    """

    hidden_dim: int
    num_heads: int
    num_layers: int
    vocab_size: int

    def setup(self):
        # flax's Embed names the table size ``num_embeddings``; it has no ``vocab_size`` field.
        self.embedding = nn.Embed(num_embeddings=self.vocab_size, features=self.hidden_dim)
        # flax.linen ships no ready-made TransformerBlock, so use the local decoder block.
        self.transformer_blocks = [
            _DecoderBlock(hidden_dim=self.hidden_dim, num_heads=self.num_heads)
            for _ in range(self.num_layers)
        ]
        self.final_norm = nn.LayerNorm()
        self.dense = nn.Dense(features=self.vocab_size)

    def __call__(self, inputs):
        """Map token ids ``(..., seq_len)`` to logits ``(..., seq_len, vocab_size)``."""
        embedded = self.embedding(inputs)
        # Attention is permutation-invariant; add position information explicitly.
        embedded = embedded + _sinusoidal_positions(inputs.shape[-1], self.hidden_dim)
        # Causal mask: position t may only attend to positions <= t, otherwise the
        # model can read the very next character it is trained to predict.
        mask = nn.make_causal_mask(inputs)
        for block in self.transformer_blocks:
            embedded = block(embedded, mask)
        return self.dense(self.final_norm(embedded))

# Define training functions
def cross_entropy_loss(logits, targets):
    """Return the mean softmax cross-entropy between ``logits`` and ``targets``.

    Args:
        logits: unnormalised scores of shape ``(..., num_classes)``.
        targets: either integer class ids of shape ``(...)`` (the encoded
            character ids used in this script) or one-hot / probability
            vectors with the same shape as ``logits``.

    Returns:
        Scalar loss averaged over every leading position.
    """
    targets = np.asarray(targets)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    if targets.ndim == log_probs.ndim - 1:
        # Integer labels: select the log-probability of the correct class.
        nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    else:
        # One-hot / soft labels: full dot product over the class axis.
        nll = -np.sum(log_probs * targets, axis=-1)
    return np.mean(nll)

def compute_metrics(logits, targets):
    """Return ``{'loss', 'accuracy'}`` for a batch of predictions.

    Args:
        logits: unnormalised scores of shape ``(..., num_classes)``.
        targets: integer class ids of shape ``(...)`` or one-hot vectors with
            the same shape as ``logits``.

    Returns:
        Dict with the mean cross-entropy ``'loss'`` and the fraction of
        positions whose arg-max prediction matches the label (``'accuracy'``).
    """
    loss = cross_entropy_loss(logits, targets)
    targets = np.asarray(targets)
    # Integer labels are used directly; arg-max only makes sense for one-hot targets.
    labels = np.argmax(targets, axis=-1) if targets.ndim == logits.ndim else targets
    accuracy = np.mean(np.argmax(logits, axis=-1) == labels)
    return {'loss': loss, 'accuracy': accuracy}

@jit
def train_step(state, batch):
    """Run one optimisation step and return ``(new_state, metrics)``.

    Args:
        state: a flax ``TrainState`` whose ``params`` are passed verbatim to
            ``state.apply_fn`` (in this script, the variables dict returned by
            ``model.init``).
        batch: dict with ``'inputs'`` token ids and ``'targets'`` labels.

    Returns:
        The updated state and the ``compute_metrics`` dict for the
        pre-update logits.
    """
    def loss_fn(params):
        logits = state.apply_fn(params, batch['inputs'])
        loss = cross_entropy_loss(logits, batch['targets'])
        return loss, logits

    # loss_fn returns (loss, logits); has_aux tells grad to differentiate only
    # the first element and pass the logits through unchanged.
    grads, logits = grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, batch['targets'])
    return state, metrics

# Hyperparameters
hidden_dim = 256
num_heads = 4
num_layers = 4
batch_size = 64
seq_len = 128  # characters of context per training example
learning_rate = 0.001
num_epochs = 10

# Create and initialize the model and optimizer.
rng = random.PRNGKey(0)
model = Transformer(hidden_dim=hidden_dim, num_heads=num_heads, num_layers=num_layers, vocab_size=len(vocab))
# Parameter shapes do not depend on batch or sequence length, so a tiny dummy
# batch suffices (initialising on the whole corpus would build a huge attention matrix).
# `params` is the full variables dict returned by init; apply_fn takes it as-is.
params = model.init(rng, np.ones((1, seq_len), dtype=np.int32))
optimizer = optax.adam(learning_rate)


# Prepare the data in batches
def generate_batches(text, batch_size, seq_len=None):
    """Yield next-character prediction batches from an integer-encoded corpus.

    Each batch is a dict with ``'inputs'`` and ``'targets'``. The targets are
    the inputs shifted one position to the right in ``text``. Trailing
    characters that do not fill a whole batch are dropped.

    Args:
        text: 1-D sliceable sequence of token ids.
        batch_size: with ``seq_len=None``, the length of each 1-D window
            (original behaviour); otherwise the number of rows per batch.
        seq_len: optional sequence length. When given, each batch is one
            contiguous span of ``batch_size * seq_len`` ids reshaped to
            ``(batch_size, seq_len)``.
    """
    window = batch_size if seq_len is None else batch_size * seq_len
    # `i < len(text) - window` guarantees the shifted targets slice is full-length.
    for i in range(0, len(text) - window, window):
        inputs = text[i:i + window]
        targets = text[i + 1:i + window + 1]
        if seq_len is not None:
            inputs = np.reshape(inputs, (batch_size, seq_len))
            targets = np.reshape(targets, (batch_size, seq_len))
        yield {'inputs': inputs, 'targets': targets}


batches = list(generate_batches(shakespeare_text, batch_size, seq_len=seq_len))
if not batches:
    raise ValueError(
        f"Corpus of {len(shakespeare_text)} characters is too short for one "
        f"batch of {batch_size} x {seq_len} characters"
    )

# Build the train state ONCE, so parameters and Adam moments carry over between
# epochs (re-creating it per epoch would restart training from the initial weights).
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

# Training loop
for epoch in range(num_epochs):
    loss_sum = 0.0
    accuracy_sum = 0.0
    for batch in batches:
        state, metrics = train_step(state, batch)
        loss_sum += metrics['loss']
        accuracy_sum += metrics['accuracy']
    # float() synchronises with the device once per epoch rather than every step.
    print(f"Epoch {epoch + 1}, Loss: {float(loss_sum) / len(batches):.4f}, "
          f"Accuracy: {float(accuracy_sum) / len(batches):.4f}")

# Save the trained parameters. Flax modules have no `save_pretrained`; serialise
# the params pytree to msgpack bytes (restore with `serialization.from_bytes`).
# NOTE(review): `vocab` is not saved; it is needed to decode generated ids.
with open("transformer_shakespeare_model.msgpack", "wb") as f:
    f.write(serialization.to_bytes(state.params))
