# Two-layer transformer architecture for the next-token prediction task

In [1]:
### Imports and Utility functions
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
from flax import linen as nn
import optax
import flax.serialization

# Utility functions: causal attention mask and cross-entropy loss
def causal_mask(size):
    """Build a lower-triangular boolean attention mask.

    Entry ``[i, j]`` is ``True`` exactly when ``j <= i``, so row ``i`` marks
    position ``i`` itself and every earlier position.

    Args:
        size: Side length of the square mask.

    Returns:
        A boolean ``jnp`` array of shape ``(size, size)``.
    """
    positions = np.arange(size)
    # Rows index the attending position, columns the attended-to position.
    allowed = positions[None, :] <= positions[:, None]
    return jnp.array(allowed)

def cross_entropy_loss(logits, labels):
    """Cross-entropy between ``logits`` and (one-hot or soft) ``labels``.

    The log-softmax is taken over the last axis of ``logits``. The
    label-weighted log-probabilities are summed over *all* axes and divided by
    the size of the leading axis only, so for inputs with extra axes (e.g. a
    sequence axis) the result is a per-example sum, averaged over examples.

    Args:
        logits: Unnormalised scores; must broadcast against ``labels``.
        labels: Target distribution with the same trailing layout as ``logits``.

    Returns:
        A scalar loss.
    """
    num_examples = labels.shape[0]
    log_likelihood = jnp.sum(labels * jax.nn.log_softmax(logits))
    return -log_likelihood / num_examples


In [2]:
### Model components: self-attention, decoder layer, decoder stack, next-token predictor
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention where every head operates at full width.

    Unlike the usual formulation, the feature axis is not split across heads:
    each head has its own ``embed_dim``-wide query, key and value projections.
    The head outputs are concatenated and projected back to ``embed_dim``.
    """
    # Width of each head's query/key/value vectors and of the module output.
    embed_dim: int
    # Number of attention heads.
    num_heads: int

    def setup(self):
        # One fused projection producing q, k and v for every head.
        self.qkv = nn.Dense(features=self.embed_dim * 3 * self.num_heads, use_bias=False)
        self.out = nn.Dense(features=self.embed_dim)

    def __call__(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Array of shape ``(batch_size, seq_length, features)``.
            mask: Optional boolean array of shape ``(seq_length, seq_length)``
                indexed ``[query, key]``; ``True`` marks key positions a query
                is allowed to attend to. It is shared by all heads and batch
                elements.

        Returns:
            Array of shape ``(batch_size, seq_length, embed_dim)``.
        """
        batch_size, seq_length, _ = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3, self.embed_dim)
        qkv = qkv.transpose((2, 0, 1, 3, 4))  # (num_heads, batch_size, seq_length, 3, embed_dim)
        q, k, v = qkv[:, :, :, 0, :], qkv[:, :, :, 1, :], qkv[:, :, :, 2, :]

        # Scaled dot-product scores: (num_heads, batch_size, query, key).
        attn_weights = jnp.einsum('hbqd,hbkd->hbqk', q, k) / np.sqrt(self.embed_dim)

        if mask is not None:
            # Broadcast the (query, key) mask over the head and batch axes.
            # Indexing it as mask[None, :, None, :] would line the mask's query
            # axis up with the batch axis instead.
            attn_weights = jnp.where(mask[None, None, :, :], attn_weights, -1e10)

        attn_weights = jax.nn.softmax(attn_weights, axis=-1)
        # Contract over the shared key axis so each query receives a weighted
        # average of the values. Using a different index letter for the value's
        # sequence axis would sum weights and values independently and make the
        # output independent of the attention pattern.
        attn_output = jnp.einsum('hbqk,hbkd->hbqd', attn_weights, v)
        attn_output = attn_output.transpose((1, 2, 0, 3))  # (batch_size, seq_length, num_heads, embed_dim)
        attn_output = attn_output.reshape(batch_size, seq_length, self.num_heads * self.embed_dim)

        return self.out(attn_output)

class TransformerDecoderLayer(nn.Module):
    """Self-attention followed by a residual connection and layer norm.

    The layer may change the feature width: when the incoming width differs
    from ``embed_dim`` the skip path is linearly projected to ``embed_dim`` so
    that it can be added to the attention output. When the widths already
    match, the input is added unchanged.
    """
    # Output width of the layer (and of its attention block).
    embed_dim: int
    # Number of attention heads.
    num_heads: int

    def setup(self):
        self.self_attn = MultiHeadSelfAttention(embed_dim=self.embed_dim, num_heads=self.num_heads)
        self.ln = nn.LayerNorm()
        # Only invoked when the input width differs from embed_dim; a Dense
        # creates its parameters on first call, so a width-preserving layer
        # carries no extra parameters.
        self.residual_proj = nn.Dense(features=self.embed_dim, use_bias=False)

    def __call__(self, x, mask=None):
        """Run the layer.

        Args:
            x: Array of shape ``(batch_size, seq_length, features)``.
            mask: Optional attention mask forwarded to the attention block.

        Returns:
            Array of shape ``(batch_size, seq_length, embed_dim)``.
        """
        attn_output = self.self_attn(x, mask=mask)
        residual = x
        if x.shape[-1] != self.embed_dim:
            # Without this, the addition below fails on mismatched widths.
            residual = self.residual_proj(x)
        x = residual + attn_output
        x = self.ln(x)
        return x

class TransformerDecoder(nn.Module):
    """Token embedding followed by a stack of decoder layers and a final norm.

    ``layer_dims`` may be given in either of two forms:

    * one entry per layer (``len(layer_dims) == len(num_heads)``): the first
      entry is used both as the embedding width and as the first layer's
      width;
    * embedding width followed by one entry per layer
      (``len(layer_dims) == len(num_heads) + 1``): ``layer_dims[0]`` is the
      embedding width and ``layer_dims[i + 1]`` is the width of layer ``i``.
    """
    # Number of distinct token ids.
    vocab_size: int
    # Feature widths; see the class docstring for the accepted forms.
    layer_dims: list
    # Number of attention heads for each layer.
    num_heads: list

    def setup(self):
        self.embedding = nn.Embed(num_embeddings=self.vocab_size, features=self.layer_dims[0])

        widths = list(self.layer_dims)
        head_counts = list(self.num_heads)
        if len(widths) == len(head_counts) + 1:
            # Leading entry is the embedding width; a plain zip would reuse it
            # for the first layer and silently drop the last width.
            widths = widths[1:]
        elif len(widths) != len(head_counts):
            raise ValueError(
                "layer_dims must have len(num_heads) or len(num_heads) + 1 entries, "
                f"got {len(widths)} layer dims for {len(head_counts)} head counts"
            )

        self.layers = [TransformerDecoderLayer(embed_dim=layer_dim, num_heads=num_heads)
                       for layer_dim, num_heads in zip(widths, head_counts)]
        self.ln = nn.LayerNorm()

    def __call__(self, x, mask=None):
        """Embed token ids and run them through every layer.

        Args:
            x: Integer token ids (they are looked up in ``nn.Embed``).
            mask: Optional attention mask passed to every layer.

        Returns:
            The layer-normalised output of the last layer.
        """
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, mask=mask)
        x = self.ln(x)
        return x

class NextTokenPredictor(nn.Module):
    """Causal decoder with a linear head producing per-position vocabulary logits."""
    # Number of distinct token ids; also the width of the output logits.
    vocab_size: int
    # Feature widths handed to the decoder.
    layer_dims: list
    # Attention head counts handed to the decoder.
    num_heads: list

    def setup(self):
        decoder_kwargs = dict(
            vocab_size=self.vocab_size,
            layer_dims=self.layer_dims,
            num_heads=self.num_heads,
        )
        self.decoder = TransformerDecoder(**decoder_kwargs)
        # Linear read-out from decoder features to vocabulary logits.
        self.out = nn.Dense(features=self.vocab_size)

    def __call__(self, x):
        """Return logits for every position of ``x``.

        Args:
            x: Token ids; axis 1 is taken as the sequence axis when sizing the
                causal mask.

        Returns:
            Logits whose last axis has size ``vocab_size``.
        """
        # Each position may only attend to itself and earlier positions.
        hidden = self.decoder(x, mask=causal_mask(x.shape[1]))
        return self.out(hidden)


In [3]:
## Model definition

# Problem sizes
S = 5   # Cardinality of the alphabet
T = 25  # Sequence length

# Attention heads per layer
m1 = 2  # first layer
m2 = 1  # second layer

# Feature widths. d_0 = S + T leaves room for a one-hot token id (S slots)
# plus a one-hot position index (T slots); each later width grows by a factor
# of (1 + number of heads in that layer).
d_0 = S + T
d_1 = d_0 * (1 + m1)  # first layer
d_2 = d_1 * (1 + m2)  # second layer

vocab_size = S
# NOTE(review): three widths are listed but only two head counts — confirm how
# the decoder is expected to pair them up.
layer_dims = [d_0, d_1, d_2]
num_heads = [m1, m2]  # kept in sync with the per-layer head counts above

model = NextTokenPredictor(vocab_size=vocab_size, layer_dims=layer_dims, num_heads=num_heads)


In [4]:
# Data importing (Tri-grams)
# Read the token sequences from disk.
sequences = np.load('sequences.npy')

# Report the shape of what was loaded.
print("Shape of loaded sequences:", sequences.shape)

# Embedding: every position becomes a d_0-wide vector holding a one-hot token
# id in slots [0, S) and a one-hot position index in slots [S, S + length).
num_sequences = sequences.shape[0]
sequence_length = sequences.shape[1]
embedded_sequences = np.zeros((num_sequences, sequence_length, d_0))

row_index = np.arange(num_sequences)[:, None]
position_index = np.arange(sequence_length)
# Token part: set slot sequences[i, j] for every (i, j) in one indexed write.
embedded_sequences[row_index, position_index[None, :], sequences] = 1
# Position part: set slot S + j at position j of every sequence.
embedded_sequences[:, position_index, S + position_index] = 1
print("Shape of embedded sequences:", embedded_sequences.shape)

Shape of loaded sequences: (1000, 25)
Shape of embedded sequences: (1000, 25, 30)


In [5]:
## Training
import os

# Directory to save model parameters
save_dir = "saved_models"
os.makedirs(save_dir, exist_ok=True)

# Lists to store training loss and model parameters
training_losses = []
model_params_list = []

num_epochs = 1000
num_batches = 16

# The model looks token ids up in its own nn.Embed table and was initialised
# on integer ids, so it is trained on the raw integer `sequences`, not on the
# float one-hot `embedded_sequences`.
tokens = jnp.asarray(sequences, dtype=jnp.int32)
# Integer division: sequences beyond num_batches * batch_size are not used.
batch_size = tokens.shape[0] // num_batches

# Define optimizer with cosine decay schedule
num_train_steps = num_epochs * num_batches
lr_schedule = optax.cosine_decay_schedule(1.0, num_train_steps)
# NOTE(review): clipping is applied to Adam's output (the update), not to the
# raw gradients, and the peak learning rate is 1.0. The usual arrangement is
# clip_by_global_norm *before* adam with a much smaller rate — confirm intent
# before changing, since it alters the training dynamics.
optimizer = optax.chain(optax.adam(learning_rate=lr_schedule), optax.clip_by_global_norm(1.0))

# Initialize model and optimizer state
rng = random.PRNGKey(0)
params = model.init(rng, jnp.zeros((1, T), dtype=jnp.int32))
optimizer_state = optimizer.init(params)


@jax.jit
def train_step(params, optimizer_state, inputs, targets):
    """Run one optimisation step and return the new state and the batch loss.

    Compiled once and reused for every batch instead of rebuilding and
    re-tracing the loss function on each iteration.
    """
    def loss_fn(p):
        logits = model.apply(p, inputs)
        return cross_entropy_loss(logits, targets)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, optimizer_state = optimizer.update(grads, optimizer_state, params)
    params = optax.apply_updates(params, updates)
    return params, optimizer_state, loss


# Training loop
for epoch in range(num_epochs):
    epoch_loss = 0.0

    for batch_idx in range(num_batches):
        # Get batch
        start = batch_idx * batch_size
        batch_tokens = tokens[start : start + batch_size]

        # Next-token prediction: the target at position t is the token at
        # position t + 1 of the SAME sequence, so shift along the time axis
        # (not the batch axis) and encode targets over the vocabulary so they
        # match the width of the logits.
        batch_sequences = batch_tokens[:, :-1]
        batch_targets = jax.nn.one_hot(batch_tokens[:, 1:], vocab_size)

        # Compute gradients and loss, then update parameters
        params, optimizer_state, loss = train_step(
            params, optimizer_state, batch_sequences, batch_targets
        )
        epoch_loss += loss

    # Calculate average epoch loss
    avg_epoch_loss = epoch_loss / num_batches
    training_losses.append(avg_epoch_loss)

    # Save model parameters
    if (epoch + 1) % 25 == 0:
        model_params_list.append(params)
        model_path = os.path.join(save_dir, f"model_epoch_{epoch+1}.params")
        with open(model_path, "wb") as f:
            f.write(flax.serialization.to_bytes(params))

    # Print epoch loss
    print(f"Epoch {epoch+1}, Loss: {avg_epoch_loss}")

# Save training losses to file
losses_path = os.path.join(save_dir, "training_losses.npy")
np.save(losses_path, np.array(training_losses))


RuntimeError: x2APIC is not supported