In [5]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit
import optax

# Configuration
max_len = 50
vocab_size = 256
hidden_size = 256
num_heads = 4
num_layers = 2
batch_size = 32
learning_rate = 3e-4
num_epochs = 10

# Load Shakespeare dataset from a local text file
def load_shakespeare_dataset(file_path):
    with open(file_path, 'r') as file:
        text = file.read()
    return text

# Preprocess text data
def preprocess_text(text, vocab_size):
    vocab = sorted(set(text))
    char_to_idx = {char: idx for idx, char in enumerate(vocab)}
    idx_to_char = {idx: char for idx, char in enumerate(vocab)}

    text_indices = jnp.array([char_to_idx[char] for char in text], dtype=jnp.int32)
    return text_indices, char_to_idx, idx_to_char

# Positional encoding
def positional_encoding(position, d_model):
    angle_rates = 1 / jnp.power(10000, (2.0 / d_model) * jnp.arange(0, d_model, 2))
    angle_rads = (position[:, jnp.newaxis] * angle_rates[jnp.newaxis, :])
    sines = jnp.sin(angle_rads)
    cosines = jnp.cos(angle_rads)
    pos_encoding = jnp.stack([sines, cosines], axis=-1)
    pos_encoding = pos_encoding.reshape((position.shape[0], -1))
    return pos_encoding

# Dot product attention
def dot_product_attention(Q, K, V, mask=None):
    key_dim = Q.shape[-1]
    Q = Q / jnp.sqrt(key_dim)

    unnormalized_logits = jnp.einsum('...qd,...kd->...qk', Q, K)

    if mask is not None:
        unnormalized_logits = jnp.where(mask, unnormalized_logits, -jnp.inf)

    attention_weights = jax.nn.softmax(unnormalized_logits, axis=-1)
    output = jnp.einsum('...qk,...kv->...qv', attention_weights, V)  # Adjust dimensions here
    return output

# Multi-head attention
def multihead_attention(Q, K, V, mask=None, num_heads=num_heads):
    head_dim = hidden_size // num_heads
    Q, K, V = [jax.lax.split(x, num_heads, axis=-1) for x in (Q, K, V)]

    output = []
    for q, k, v in zip(Q, K, V):
        output.append(dot_product_attention(q, k, v, mask))

    output = jnp.concatenate(output, axis=-1)
    return output


# Transformer encoder layer
def transformer_encoder_layer(x, weights, mask=None):
    q = jnp.dot(x, weights[0])
    k = v = x
    attention_output = multihead_attention(q, k, v, mask)
    x = x + attention_output
    x = x + jnp.dot(x, weights[1])
    return x

# Transformer decoder layer
def transformer_decoder_layer(x, encoder_output, weights, mask=None):
    q = jnp.dot(x, weights[0])
    k = v = x
    attention_output = multihead_attention(q, k, v, mask)
    x = x + attention_output
    x = x + jnp.dot(x, weights[1])
    q = x
    k = encoder_output
    v = encoder_output
    attention_output = multihead_attention(q, k, v, mask)
    x = x + attention_output
    x = x + jnp.dot(x, weights[2])
    return x

# Transformer model
def transformer_model(encoder_inputs, decoder_inputs, params):
    encoder_mask = None  # Optional: Add mask for padded values
    decoder_mask = None  # Optional: Add mask for future tokens
    
    # Embedding
    encoder_embedding = jnp.dot(encoder_inputs, params['embed_weights'])
    decoder_embedding = jnp.dot(decoder_inputs, params['embed_weights'])
    
    # Positional Encoding
    encoder_embedding += params['positional_encodings'][:encoder_inputs.shape[1], :]
    decoder_embedding += params['positional_encodings'][:decoder_inputs.shape[1], :]
    
    # Encoder
    for layer_weights in params['encoder_weights']:
        encoder_embedding = transformer_encoder_layer(encoder_embedding, layer_weights, encoder_mask)
    
    # Decoder
    for layer_weights in params['decoder_weights']:
        decoder_embedding = transformer_decoder_layer(decoder_embedding, encoder_embedding, layer_weights, decoder_mask)
    
    # Output layer
    logits = jnp.dot(decoder_embedding, params['output_weights'])
    return logits

# Loss function
def cross_entropy_loss(logits, targets):
    return -jnp.mean(jax.nn.log_softmax(logits) * targets)

# Initialize optimizer
optimizer = optax.adam(learning_rate)

@jit
def train_step(params, encoder_inputs, decoder_inputs, targets):
    logits = transformer_model(encoder_inputs, decoder_inputs, params)
    targets_one_hot = jax.nn.one_hot(targets, vocab_size)
    loss = cross_entropy_loss(logits, targets_one_hot)
    grads = grad(loss)(params)
    updates, _ = optimizer.update(grads, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, loss

# Load Shakespeare dataset
shakespeare_text = load_shakespeare_dataset('shakespeare.txt')

# Preprocess text data
shakespeare_indices, char_to_idx, idx_to_char = preprocess_text(shakespeare_text, vocab_size)

# Initialize parameters for the Transformer model
rng = random.PRNGKey(42)
params = {
    'embed_weights': random.normal(rng, (vocab_size, hidden_size)),
    'positional_encodings': positional_encoding(jnp.arange(max_len), hidden_size),
    'encoder_weights': [random.normal(rng, (hidden_size, hidden_size)) for _ in range(num_layers)],
    'decoder_weights': [random.normal(rng, (hidden_size, hidden_size)) for _ in range(num_layers)],
    'output_weights': random.normal(rng, (hidden_size, vocab_size)),
}

# Training loop
for epoch in range(num_epochs):
    for batch_start in range(0, len(shakespeare_indices) - max_len, batch_size):
        encoder_batch = shakespeare_indices[batch_start:batch_start + batch_size]
        decoder_batch = shakespeare_indices[batch_start + 1:batch_start + batch_size + 1]
        targets_batch = shakespeare_indices[batch_start + 2:batch_start + batch_size + 2]

        params, loss = train_step(params, encoder_batch, decoder_batch, targets_batch)
        # Optionally log/print


TypeError: dot_general requires contracting dimensions to have the same shape, got (32,) and (256,).