This is ChatGPT output

In [9]:
import jax
import jax.numpy as jnp
from jax import random
import numpy as np

class MultiHeadSelfAttention:
    """Scaled dot-product self-attention over several heads, without biases or masking.

    Args:
        d_model: Size of the last axis of the input and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
        key: Optional JAX PRNG key used to initialise the four projection
            matrices. When omitted, the fixed keys ``PRNGKey(0)`` ..
            ``PRNGKey(3)`` are used, so every instance built without a key
            starts from exactly the same weights. Pass a distinct key per
            layer to get independently initialised layers.

    Attributes:
        wq, wk, wv, wo: ``(d_model, d_model)`` Xavier-uniform matrices for the
            query, key, value and output projections.
        depth: Per-head feature size, ``d_model // num_heads``.
    """

    def __init__(self, d_model, num_heads, key=None):
        self.d_model = d_model
        self.num_heads = num_heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.depth = d_model // num_heads

        if key is None:
            # Historical default: one fixed seed per projection matrix.
            keys = [random.PRNGKey(seed) for seed in range(4)]
        else:
            # Derive four independent sub-keys from the caller's key.
            keys = random.split(key, 4)

        init = jax.nn.initializers.xavier_uniform()
        shape = (d_model, d_model)
        self.wq = init(keys[0], shape)
        self.wk = init(keys[1], shape)
        self.wv = init(keys[2], shape)
        self.wo = init(keys[3], shape)

    def split_heads(self, x):
        """Reshape ``(batch, seq, d_model)`` into ``(batch, num_heads, seq, depth)``."""
        batch_size = x.shape[0]
        return x.reshape(batch_size, -1, self.num_heads, self.depth).transpose(0, 2, 1, 3)

    def __call__(self, x):
        """Apply self-attention to ``x`` of shape ``(batch, seq, d_model)``.

        Returns:
            An array of shape ``(batch, seq, d_model)``.
        """
        q = self.split_heads(jnp.dot(x, self.wq))
        k = self.split_heads(jnp.dot(x, self.wk))
        v = self.split_heads(jnp.dot(x, self.wv))

        # (batch, heads, seq, seq) similarity scores, scaled by sqrt(depth).
        scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / jnp.sqrt(self.depth)
        attention_weights = jax.nn.softmax(scores, axis=-1)
        attention_output = jnp.matmul(attention_weights, v)

        # Undo split_heads: move heads next to depth and merge them back into d_model.
        attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape[0], -1, self.d_model)
        return jnp.dot(attention_output, self.wo)

class FeedForward:
    """Two-layer position-wise feed-forward network with a ReLU in between (no biases).

    Args:
        d_model: Size of the last axis of the input and of the output.
        d_ff: Size of the hidden layer.
        key: Optional JAX PRNG key used to initialise both weight matrices.
            When omitted, the fixed keys ``PRNGKey(4)`` and ``PRNGKey(5)`` are
            used, so every instance built without a key starts from the same
            weights. Pass a distinct key per layer for independent weights.

    Attributes:
        w1: ``(d_model, d_ff)`` Xavier-uniform input projection.
        w2: ``(d_ff, d_model)`` Xavier-uniform output projection.
    """

    def __init__(self, d_model, d_ff, key=None):
        if key is None:
            # Historical default: one fixed seed per matrix.
            key_in, key_out = random.PRNGKey(4), random.PRNGKey(5)
        else:
            key_in, key_out = random.split(key)

        init = jax.nn.initializers.xavier_uniform()
        self.w1 = init(key_in, (d_model, d_ff))
        self.w2 = init(key_out, (d_ff, d_model))

    def __call__(self, x):
        """Return ``relu(x @ w1) @ w2``; the last axis of ``x`` must be ``d_model``."""
        hidden = jax.nn.relu(jnp.dot(x, self.w1))
        return jnp.dot(hidden, self.w2)

class PositionalEncoding:
    """Fixed sinusoidal position table that is added to a batch of sequences.

    Even feature columns hold ``sin(pos * freq_i)`` and odd columns hold
    ``cos(pos * freq_i)``, with ``freq_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Number of feature columns; may be even or odd.
        max_len: Number of positions (rows) to precompute.

    Attributes:
        encoding: NumPy array of shape ``(max_len, d_model)``.
    """

    def __init__(self, d_model, max_len):
        positions = np.arange(max_len)[:, np.newaxis]
        div_terms = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        angles = positions * div_terms  # (max_len, ceil(d_model / 2))

        self.encoding = np.zeros((max_len, d_model))
        self.encoding[:, 0::2] = np.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        self.encoding[:, 1::2] = np.cos(angles[:, : d_model // 2])

    def __call__(self, x):
        """Add the first ``x.shape[1]`` rows of the table to ``x``.

        Args:
            x: Array whose axis 1 is the sequence axis and whose last axis
                has ``d_model`` entries.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.shape[1]
        max_len = self.encoding.shape[0]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len} of the positional encoding"
            )
        return x + self.encoding[:seq_len, :]

class TransformerBlock:
    """One encoder layer: self-attention then feed-forward, each wrapped in a residual add.

    No normalisation or dropout is applied.
    """

    def __init__(self, d_model, num_heads, d_ff):
        self.attention = MultiHeadSelfAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)

    def __call__(self, x):
        """Return ``h + feed_forward(h)`` where ``h = x + attention(x)``."""
        attended = x + self.attention(x)
        return attended + self.feed_forward(attended)

class Transformer:
    """Stack of TransformerBlocks mapping integer token ids to per-position vocabulary logits.

    Args:
        d_model: Embedding / hidden size.
        num_heads: Attention heads per block.
        d_ff: Hidden size of each block's feed-forward layer.
        num_layers: Number of TransformerBlocks.
        max_len: Number of positions covered by the positional encoding.
        vocab_size: Number of rows in the embedding table and of output logits.
    """

    def __init__(self, d_model, num_heads, d_ff, num_layers, max_len, vocab_size):
        xavier = jax.nn.initializers.xavier_uniform()

        self.d_model = d_model
        self.embedding = xavier(random.PRNGKey(6), (vocab_size, d_model))
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        self.layers = [TransformerBlock(d_model, num_heads, d_ff) for _ in range(num_layers)]
        self.final_layer = xavier(random.PRNGKey(7), (d_model, vocab_size))

    def __call__(self, x):
        """Return logits of shape ``(batch_size, seq_len, vocab_size)`` for integer ids ``x``."""
        # Row lookup by token id (not a matmul with one-hot inputs):
        # (batch_size, seq_len) -> (batch_size, seq_len, d_model).
        hidden = self.positional_encoding(self.embedding[x])
        for block in self.layers:
            hidden = block(hidden)
        return jnp.dot(hidden, self.final_layer)

# Example usage
def main():
    """Build a small Transformer and print the shape of its output for dummy token ids."""
    hyperparams = dict(
        d_model=128,
        num_heads=8,
        d_ff=512,
        num_layers=4,
        max_len=100,
        vocab_size=10000,
    )
    model = Transformer(**hyperparams)

    # Dummy batch: 2 sequences of 10 token ids, all equal to 1.
    tokens = jnp.ones((2, 10), dtype=jnp.int32)
    logits = model(tokens)

    print(logits.shape)

# if __name__ == "__main__":
#     main()


In [10]:
# Notebook cell: run the example; the shape it prints is shown on the line below.
main()

(2, 10, 10000)
