In [1]:
import pathlib
import random
import string
import re
import numpy as np

import jax.numpy as jnp
import optax

from flax import nnx

import tiktoken
import grain.python as grain
import tqdm

from datasets import load_dataset


In [52]:
class LuongAttention(nnx.Module):
    def __init__(self, hidden_size, src_vocab_size, tgt_vocab_size, rngs=nnx.Rngs):
        self.source_embedding = nnx.Embed(src_vocab_size, hidden_size, rngs=rngs)
        self.target_embedding = nnx.Embed(tgt_vocab_size, hidden_size, rngs=rngs)

        # Initialize RNNs for encoder and decoder
        self.encoder = nnx.RNN(
                nnx.GRUCell(hidden_size, hidden_size, rngs=rngs),
                return_carry=True
        )
        self.decoder = nnx.RNN(
                nnx.GRUCell(hidden_size, hidden_size, rngs=rngs),
                return_carry=True
        )

        self.W_c = nnx.Linear(hidden_size * 2, hidden_size, rngs=rngs)
        self.W_y = nnx.Linear(hidden_size, tgt_vocab_size, rngs=rngs)

    def __call__(self, source, target, h_init):
        # Compute embeddings; shape: (batch, seq_len, features)
        source_seq = self.source_embedding(source)
        target_seq = self.target_embedding(target)

        # Encoder and decoder passes
        h_t, h_final = self.encoder(source_seq, initial_carry=h_init)
        s_t, s_final = self.decoder(target_seq, initial_carry=h_final)

        # Handle shape transformations properly
        # If outputs have shape (seq_len, batch, hidden), transpose to (batch, seq_len, hidden)
        if h_t.shape[0] != source.shape[0]:
            h_t = jnp.transpose(h_t, (1, 0, 2))  # (batch, src_seq_len, hidden)
            s_t = jnp.transpose(s_t, (1, 0, 2))  # (batch, tgt_seq_len, hidden)

        # Compute attention scores: (batch, tgt_seq_len, hidden) @ (batch, hidden, src_seq_len)
        e_t_i = jnp.matmul(s_t, jnp.transpose(h_t, (0, 2, 1)))

        # Apply softmax to get alignment weights
        alignment_scores = nnx.softmax(e_t_i, axis=-1)

        # Compute context vectors: (batch, tgt_seq_len, src_seq_len) @ (batch, src_seq_len, hidden)
        c_t = jnp.matmul(alignment_scores, h_t)

        # Compute combined representation
        s_hat_t = nnx.tanh(self.W_c(jnp.concatenate([s_t, c_t], axis=-1)))

        # Project to vocabulary space
        y_t = self.W_y(s_hat_t)

        # Return in expected format
        return jnp.transpose(y_t, (1, 0, 2))  # (tgt_seq_len, batch, tgt_vocab_size)

In [53]:
# Inicializar red
net = LuongAttention(
        hidden_size=256,
        src_vocab_size=3371,
        tgt_vocab_size=2810,
        rngs=nnx.Rngs(42)
)

# Crear entradas de prueba
source = jnp.array([[10, 23, 5]])  # (batch=1, src_seq_len=3)
target = jnp.array([[4, 9]])       # (batch=1, tgt_seq_len=2)

# Create initial hidden state with proper dimensions
batch_size = source.shape[0]
h_init = jnp.zeros((batch_size, 256))  # (batch=1, hidden_size)

# Forward pass
output = net(source, target, h_init)
print(output.shape)  # Debería mostrar (2, 1, 2810)

ValueError: axis 2 is out of bounds for array of dimension 2