In [43]:
import jax
import jax.numpy as jnp
import numpy as np
import math
from flax import linen as nn
from flax.linen.initializers import constant, ones, zeros
from flax.training.train_state import TrainState
import torch
import torchvision
import torchvision.transforms as transforms
import optax
from tqdm.notebook import tqdm
import plotly.express as px
import time
from typing import Any, Sequence, Tuple, Optional, Dict

Word Embeddings

In [17]:
class Embedding(nn.Module):
    """Learnable lookup table mapping integer token ids to dense vectors.

    Attributes:
        vocab_size: number of distinct token ids the table holds.
        emb_dim: length of each embedding vector.
    """
    vocab_size: int
    emb_dim: int

    def setup(self):
        # One table, registered under the submodule name "embed" (this fixes the param path).
        self.embed = nn.Embed(num_embeddings=self.vocab_size, features=self.emb_dim)

    def __call__(self, x):
        """Return the embedding of every id in ``x``; a trailing ``emb_dim`` axis is appended."""
        embedded = self.embed(x)
        return embedded

Positional Encoding



In [19]:
class PositionalEmbedding(nn.Module):
    """Adds fixed sinusoidal positional encodings to its input.

    Even feature columns hold ``sin(pos * freq)`` and odd columns ``cos(pos * freq)``.
    The table is a constant built once in ``setup``, not a learnable parameter.

    Attributes:
        max_seq_len: number of positions covered by the precomputed table.
        embed_dim: feature size; must match the last axis of the input.
    """
    max_seq_len: int
    embed_dim: int

    def setup(self):
        position = np.arange(0, self.max_seq_len, dtype=np.float32)[:, None]
        div_term = np.exp(np.arange(0, self.embed_dim, 2) * (-math.log(10000.0) / self.embed_dim))
        angles = position * div_term  # [max_seq_len, ceil(embed_dim / 2)]

        pe = np.zeros((self.max_seq_len, self.embed_dim))
        pe[:, 0::2] = np.sin(angles)
        # There are only floor(embed_dim / 2) odd columns. For an odd embed_dim that is
        # one fewer than the number of frequencies, so drop the last one here;
        # otherwise the assignment fails with a shape mismatch.
        pe[:, 1::2] = np.cos(angles[:, : self.embed_dim // 2])

        # Leading singleton axis so the table broadcasts over the batch dimension.
        self.pe = jax.device_put(pe[None])

    def __call__(self, x):
        """Return ``x`` plus the encodings of its first ``x.shape[1]`` positions.

        Assumes ``x`` is (batch, seq_len, embed_dim) with seq_len <= max_seq_len.
        Longer inputs fail to broadcast against the sliced table.
        """
        seq_len = x.shape[1]
        return x + self.pe[:, :seq_len]

Multi-Head Attention (used for both self-attention and encoder-decoder cross-attention)

In [81]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across ``num_heads`` parallel heads.

    Attributes:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """
    embed_dim: int = 512
    num_heads: int = 8

    def setup(self):
        assert self.embed_dim % self.num_heads == 0
        self.single_head_dim = self.embed_dim // self.num_heads

        # Bias-free input projections. The attribute names fix the parameter tree layout.
        self.query = nn.Dense(self.embed_dim, use_bias=False)
        self.key = nn.Dense(self.embed_dim, use_bias=False)
        self.value = nn.Dense(self.embed_dim, use_bias=False)

        self.out = nn.Dense(self.embed_dim)

    def __call__(self, key, query, value, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            key: (batch, len_k, embed_dim) source of attention keys.
            query: (batch, len_q, embed_dim) source of attention queries.
            value: (batch, len_k, embed_dim) source of values; its length must equal ``key``'s.
            mask: optional array broadcastable to (batch, heads, len_q, len_k).
                Scores where ``mask == 0`` are replaced by -1e9 before the softmax.

        Returns:
            (batch, len_q, embed_dim) array after the output projection.
        """
        n_batch, len_k = key.shape[0], key.shape[1]
        len_q = query.shape[1]

        def to_heads(t, length):
            # (batch, length, embed_dim) -> (batch, heads, length, head_dim)
            split = t.reshape(n_batch, length, self.num_heads, self.single_head_dim)
            return split.transpose(0, 2, 1, 3)

        q_heads = to_heads(self.query(query), len_q)
        k_heads = to_heads(self.key(key), len_k)
        v_heads = to_heads(self.value(value), len_k)

        # (batch, heads, len_q, len_k) similarity scores, scaled by sqrt(head_dim).
        scores = jnp.matmul(q_heads, jnp.swapaxes(k_heads, -1, -2)) / jnp.sqrt(self.single_head_dim)
        if mask is not None:
            scores = jnp.where(mask == 0, -1e9, scores)
        weights = nn.softmax(scores, axis=-1)

        # Weighted sum of values, then merge the heads back into one feature axis.
        per_head = jnp.matmul(weights, v_heads)
        merged = per_head.transpose(0, 2, 1, 3).reshape(n_batch, len_q, self.embed_dim)
        return self.out(merged)

DyT (Dynamic Tanh), used in place of LayerNorm

In [82]:
class DyT(nn.Module):
    """Dynamic Tanh: elementwise ``tanh(alpha * x) * weight + bias``.

    This notebook uses it where a normalization layer would normally sit (the
    "Add & normalize" steps). ``alpha`` is one learnable scalar. ``weight`` and
    ``bias`` are learnable vectors over the last axis.

    Attributes:
        num_features: size of the input's last axis.
        alpha_init: starting value of ``alpha``.
    """
    num_features: int
    alpha_init: float = 0.5

    def setup(self):
        feature_shape = (self.num_features,)
        # Shape () makes alpha a scalar parameter.
        self.alpha = self.param('alpha', constant(self.alpha_init), ())
        self.weight = self.param('weight', ones, feature_shape)
        self.bias = self.param('bias', zeros, feature_shape)

    def __call__(self, x):
        squashed = nn.tanh(self.alpha * x)
        return squashed * self.weight + self.bias

Encoder

In [83]:
class TransformerBlock(nn.Module):
    """Post-norm transformer layer with two sub-layers.

    The first sub-layer is attention plus a residual, followed by DyT and dropout.
    The second is a position-wise MLP plus a residual, followed by DyT and dropout.

    Attributes:
        embed_dim: model width.
        expansion_factor: hidden width of the MLP, as a multiple of ``embed_dim``.
        n_heads: number of attention heads.
    """
    embed_dim: int
    expansion_factor: int = 4
    n_heads: int = 8

    def setup(self):
        self.attention = MultiHeadAttention(self.embed_dim, self.n_heads)

        self.dyt1 = DyT(self.embed_dim)
        self.dyt2 = DyT(self.embed_dim)

        hidden_dim = self.embed_dim * self.expansion_factor
        self.feed_forward = nn.Sequential([
            nn.Dense(hidden_dim),
            nn.relu,
            nn.Dense(self.embed_dim),
        ])

        self.dropout1 = nn.Dropout(rate=0.2)
        self.dropout2 = nn.Dropout(rate=0.2)

    def __call__(self, key, query, value, deterministic=False):
        """Apply the layer.

        Attention is unmasked. The first residual is taken from ``query``, so the
        output has ``query``'s shape. With ``deterministic=False`` dropout is active
        and a 'dropout' RNG must be supplied to ``apply``.
        """
        # Sub-layer 1: attention + residual on the query stream.
        h = self.attention(key, query, value) + query
        h = self.dropout1(self.dyt1(h), deterministic=deterministic)

        # Sub-layer 2: feed-forward + residual.
        ff = self.feed_forward(h) + h
        return self.dropout2(self.dyt2(ff), deterministic=deterministic)



In [84]:
class TransformerEncoder(nn.Module):
    """Token embedding plus sinusoidal positions, followed by ``num_layers`` self-attention blocks.

    Attributes:
        seq_len: number of positions covered by the positional table.
        vocab_size: number of source token ids.
        embed_dim: model width.
        num_layers: number of stacked TransformerBlocks.
        expansion_factor: MLP width multiplier inside each block.
        n_heads: attention heads per block.
    """
    seq_len : int
    vocab_size: int
    embed_dim: int
    num_layers: int = 2
    expansion_factor: int = 4
    n_heads: int = 8

    def setup(self):
        self.embedding = Embedding(self.vocab_size, self.embed_dim)
        self.positional_embedding = PositionalEmbedding(self.seq_len, self.embed_dim)

        self.layers = [
            TransformerBlock(self.embed_dim, self.expansion_factor, self.n_heads)
            for _ in range(self.num_layers)
        ]

    def __call__(self, x, deterministic=False):
        """Encode integer ids ``x`` into hidden states of width ``embed_dim``."""
        hidden = self.positional_embedding(self.embedding(x))
        for block in self.layers:
            # Self-attention: the same tensor serves as key, query and value.
            hidden = block(hidden, hidden, hidden, deterministic=deterministic)
        return hidden


Decoder

In [85]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, then cross-attention over the encoder output.

    Attributes:
        embed_dim: model width.
        expansion_factor: MLP width multiplier in the cross-attention block.
        n_heads: number of attention heads.
    """
    embed_dim: int
    expansion_factor: int = 4
    n_heads: int = 8

    def setup(self):
        self.attention = MultiHeadAttention(self.embed_dim, self.n_heads)
        self.dyt = DyT(self.embed_dim)
        self.dropout = nn.Dropout(rate=0.2)
        self.transformer_block = TransformerBlock(self.embed_dim, self.expansion_factor, self.n_heads)

    def __call__(self, key, query, x, mask=None, deterministic=False):
        """Run the decoder layer.

        Args:
            key: encoder output, used as the cross-attention keys.
            query: encoder output, used as the cross-attention values. The name is
                kept for backward compatibility; callers pass the encoder output here.
            x: decoder hidden states, presumably (batch, tgt_len, embed_dim). TODO confirm.
            mask: optional self-attention mask; scores where ``mask == 0`` are blocked.
            deterministic: disables dropout when True.

        Returns:
            Array with the same shape as ``x``.
        """
        # Masked self-attention over the decoder's own states.
        attention = self.attention(x, x, x, mask=mask)
        # Add & normalize
        value = self.dropout(self.dyt(attention + x), deterministic=deterministic)
        # Cross-attention: queries come from the decoder, keys/values from the encoder.
        # TransformerBlock takes (key, query, value) and adds its residual to `query`,
        # so the decoder state must go in the query slot. In the value slot, the output
        # would take the encoder's length, and the value reshape would fail whenever
        # tgt_len != src_len (e.g. when decoding one token at a time).
        out = self.transformer_block(key, value, query, deterministic=deterministic)

        return out

In [108]:
class TransformerDecoder(nn.Module):
    """Stack of decoder blocks with target embedding and a softmax vocabulary head.

    The stack is: target embedding, sinusoidal positions, dropout, ``num_layers``
    DecoderBlocks, then a dense projection to ``target_vocab_size`` and a softmax.

    Attributes:
        target_vocab_size: number of target token ids.
        embed_dim: model width.
        seq_len: number of positions covered by the positional table.
        num_layers: number of stacked DecoderBlocks.
        expansion_factor: MLP width multiplier inside each block.
        n_heads: attention heads per block.
    """
    target_vocab_size: int
    embed_dim: int
    seq_len: int
    num_layers: int = 2
    expansion_factor: int = 4
    n_heads: int = 8

    def setup(self):
        self.word_embedding = nn.Embed(self.target_vocab_size, self.embed_dim)
        self.positional_embedding = PositionalEmbedding(self.seq_len, self.embed_dim)

        self.layers = [DecoderBlock(self.embed_dim, self.expansion_factor, self.n_heads) for _ in range(self.num_layers)]
        self.fc_out = nn.Dense(self.target_vocab_size)
        self.dropout = nn.Dropout(rate=0.2)

    def __call__(self, x, enc_out, mask=None, deterministic=False):
        """Decode target ids ``x`` against the encoder output ``enc_out``.

        Args:
            x: integer target ids of shape (batch, tgt_len).
            enc_out: encoder output, passed to every layer's cross-attention.
            mask: optional attention mask. Only its ``[:, :, :tgt_len, :tgt_len]``
                window is used, so it presumably has rank 4. TODO confirm.
            deterministic: disables dropout when True.

        Returns:
            Probabilities (softmax already applied) whose last axis has size
            ``target_vocab_size``. These are not logits: a logit-based loss such as
            ``optax.softmax_cross_entropy`` would apply softmax a second time.
        """
        _, tgt_len = x.shape
        # The mask is the same for every layer, so trim it once rather than per layer.
        layer_mask = None if mask is None else mask[:, :, :tgt_len, :tgt_len]

        h = self.word_embedding(x)
        h = self.positional_embedding(h)
        h = self.dropout(h, deterministic=deterministic)

        for layer in self.layers:
            h = layer(enc_out, enc_out, h, mask=layer_mask, deterministic=deterministic)

        return nn.softmax(self.fc_out(h), axis=-1)

In [110]:
class Transformer(nn.Module):
    """Encoder-decoder model built from a TransformerEncoder and a TransformerDecoder.

    Attributes:
        embed_dim: model width.
        src_vocab_size: number of source token ids.
        tgt_vocab_size: number of target token ids.
        seq_len: number of positions covered by both positional tables.
        num_layers: layers in each of the encoder and the decoder.
        expansion_factor: MLP width multiplier inside each block.
        n_heads: attention heads per block.
    """
    embed_dim: int
    src_vocab_size: int
    tgt_vocab_size: int
    seq_len: int
    num_layers: int = 2
    expansion_factor: int = 4
    n_heads: int = 8

    def setup(self):
        self.encoder = TransformerEncoder(self.seq_len, self.src_vocab_size, self.embed_dim, self.num_layers, self.expansion_factor, self.n_heads)
        self.decoder = TransformerDecoder(self.tgt_vocab_size, self.embed_dim, self.seq_len, self.num_layers, self.expansion_factor, self.n_heads)

    def make_trg_mask(self, trg):
        """Causal mask of shape (batch, 1, trg_len, trg_len) for a (batch, trg_len) id array.

        Entry [b, 0, i, j] is 1 when j <= i (position i may attend to j), otherwise 0.
        """
        batch_size, trg_len = trg.shape
        # Lower-triangular matrix of ones = causal mask.
        trg_mask = jnp.tril(jnp.ones((trg_len, trg_len)))
        # Broadcast over the batch, with a singleton head axis.
        trg_mask = jnp.broadcast_to(trg_mask[None, None, :, :], (batch_size, 1, trg_len, trg_len))
        return trg_mask

    def decode(self, src, trg, deterministic=False):
        """Greedy autoregressive decoding.

        Encodes ``src`` once. Then, starting from the prefix ``trg`` (batch, prefix_len),
        each step does three things:
          1. Runs the decoder on the tokens generated so far, under a causal mask.
          2. Takes the argmax at the last position.
          3. Appends that token to the history.
        It generates ``src.shape[1]`` tokens. Only the most recent ``self.seq_len``
        tokens are fed to the decoder, because its positional table covers no more
        positions than that. Pass ``deterministic=True`` at inference time; otherwise
        dropout is active and a 'dropout' RNG is required.

        Returns:
            List of length ``src.shape[1]``. Element i holds the i-th generated token id
            for every batch row, with shape (batch,).
        """
        enc_out = self.encoder(src, deterministic=deterministic)
        num_steps = src.shape[1]

        tokens = trg
        out_labels = []
        for _ in range(num_steps):
            window = tokens[:, -self.seq_len:]
            probs = self.decoder(window, enc_out, mask=self.make_trg_mask(window), deterministic=deterministic)
            next_token = jnp.argmax(probs[:, -1, :], axis=-1)
            out_labels.append(next_token)
            # Feed the whole history back, not just the newest token; otherwise
            # every step would decode with no context.
            tokens = jnp.concatenate([tokens, next_token[:, None].astype(tokens.dtype)], axis=1)

        return out_labels

    def __call__(self, src, trg, deterministic=False):
        """Teacher-forced forward pass: decode the whole ``trg`` under a causal mask."""
        trg_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src, deterministic=deterministic)
        out = self.decoder(trg, enc_out, mask=trg_mask, deterministic=deterministic)
        return out


In [111]:
# Test Transformer: build the model, initialise it, and run one forward pass on all-ones ids.

src_vocab_size = 11
target_vocab_size = 11
num_layers = 6
seq_length = 12
embed_dim = 512
expansion_factor = 4
n_heads = 8
transformer = Transformer(
    embed_dim=embed_dim,
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=target_vocab_size,
    seq_len=seq_length,
    num_layers=num_layers,
    expansion_factor=expansion_factor,
    n_heads=n_heads,
)
# deterministic=True disables dropout, so only the params RNG is needed.
params = transformer.init(
    jax.random.PRNGKey(0),
    jnp.ones((1, seq_length), dtype=jnp.int32),
    jnp.ones((1, seq_length), dtype=jnp.int32),
    deterministic=True,
)
out = transformer.apply(
    params,
    jnp.ones((1, seq_length), dtype=jnp.int32),
    jnp.ones((1, seq_length), dtype=jnp.int32),
    deterministic=True,
)

In [115]:
# Toy batch: 2 rows x 12 ids each. Every id is below the vocab size of 11.
src = jnp.array([
    [0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1],
    [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1],
])
target = jnp.array([
    [0, 1, 7, 4, 3, 5, 9, 2, 8, 10, 9, 1],
    [0, 1, 5, 6, 2, 4, 7, 6, 2, 8, 10, 1],
])

In [116]:
# Teacher-forced forward pass over the whole target; deterministic=True turns dropout off.
out = transformer.apply(params, src, target, deterministic=True)

In [117]:
# Expected: (batch, sequence length, target vocab size)
jnp.shape(out)

(2, 12, 11)