In [1]:
import jax
import jax.numpy as jnp
import haiku as hk
import optax

In [2]:
# Define the Transformer model
class Transformer(hk.Module):
    def __init__(self, num_layers, num_heads, d_model, d_ff):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_ff = d_ff

        # Define the layers of the Transformer
        self.enc_layers = [
            hk.transform(
                lambda x: self._encoder_layer(x, self.num_heads, self.d_model, self.d_ff)
            )
            for _ in range(self.num_layers)
        ]

    def __call__(self, inputs):
        # Encoder input is the same as the decoder output
        encoder_input = inputs
        # Run the encoder layers
        for enc_layer in self.enc_layers:
            encoder_input = enc_layer(encoder_input)
        return encoder_input

    def _encoder_layer(self, inputs, num_heads, d_model, d_ff):
        # Multi-head self-attention layer
        x = hk.MultiHeadAttention(
            key_size=d_model // num_heads,
            value_size=d_model // num_heads,
            num_heads=num_heads,
            w_init=hk.initializers.TruncatedNormal(stddev=1.0 / jnp.sqrt(d_model)),
        )(inputs, inputs)
        # Add and normalize
        x = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x + inputs)
        # Feedforward layer
        x = hk.Linear(d_ff)(x)
        x = jax.nn.gelu(x)
        # Add and normalize
        x = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x + inputs)
        return x

In [3]:
# Define the loss function
def loss_fn(model, inputs, targets):
    # Run the model on the inputs
    logits = model(inputs)
    # Compute the cross-entropy loss
    ce_loss = jnp.mean(optax.softmax_cross_entropy(logits, targets))
    # Return the average loss across the batch
    return ce_loss

In [4]:
# Define the training step
@jax.jit
def train_step(model, optimizer, inputs, targets):
    # Compute the gradients of the loss with respect to the model parameters
    grads = jax.grad(loss_fn)(model, inputs, targets)
    # Update the model parameters using the optimizer
    updates, optimizer_state = optimizer.update(grads)
    model = optax.apply_updates(model, updates)
    # Return the updated model and optimizer state
    return model, optimizer_state

In [5]:
# Define the data
with open('Y.txt', 'r') as f:
    Y = f.read().splitlines()
with open('X.txt', 'r') as f:
    X = f.read().splitlines()

In [6]:
# Define the hyperparameters
num_layers = 4
num_heads = 8
d_model = 512
d_ff = 2048
learning_rate = 1e-4
batch_size = 32
num_epochs = 10

In [7]:
# Create the Transformer model
model = Transformer(num_layers=num_layers, num_heads=num_heads, d_model=d_model, d_ff=d_ff)

ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.

In [None]:
# Create the optimizer
optimizer = optax.adam(learning_rate=learning_rate)

In [None]:
# Train the model
for epoch in range(num_epochs):
    # Shuffle the data
    perm = jax.random.permutation(jax.random.PRNGKey(0), len(X))