# 2-layer architecture for next token prediction task 

In [18]:
### Imports and Utility functions
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
from flax import linen as nn
import optax
import flax.serialization

# Define utility functions for attention mechanism
def causal_mask(size):
    mask = np.tril(np.ones((size, size), dtype=np.bool_), k=0)
    return jnp.array(mask)

def cross_entropy_loss(logits, labels):
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.sum(labels * log_probs) / labels.shape[0]


In [19]:
class MultiHeadSelfAttention(nn.Module):
    embed_dim: int
    num_heads: int

    def setup(self):
        self.head_dim = self.embed_dim 
        self.qkv = nn.Dense(features=self.embed_dim * 3 * self.num_heads, use_bias=False)
        self.out = nn.Dense(features=self.embed_dim)

    def __call__(self, x, mask=None):
        batch_size, seq_length,_ = x.shape
        qkv = self.qkv(x) #qkb are not the params, theyre W_Q, W_K, W_V applied to the batch !
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3, self.head_dim)
        qkv = qkv.transpose((2, 0, 1, 3, 4))  # (num_heads, batch_size, seq_length, 3, head_dim)
        q, k, v = qkv[:, :, :, 0, :], qkv[:, :, :, 1, :], qkv[:, :, :, 2, :]
        attn_weights = jnp.einsum('hbqd,hbkd->hbqk', q, k) / jnp.sqrt(self.head_dim) # einstein summation
        #then check if we want to normalize or not / jnp.sqrt(self.head_dim)

        if mask is not None:
            attn_weights = jnp.where(mask[None, None, :, :], attn_weights, -1e10)

        attn_weights = jax.nn.softmax(attn_weights, axis=-1) #axis=-1 is the last axis i.e. 'row-wise'
        attn_output = jnp.einsum('hbqk,hbvd->hbqd', attn_weights, v) # (num_heads, batch_size, seq_length, head_dim)
        attn_output = attn_output.transpose((1, 2, 0, 3))  # (batch_size, seq_length, num_heads, head_dim)
        attn_output = attn_output.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
        return self.out(attn_output)


class TransformerDecoderLayer(nn.Module):
    embed_dim: int
    num_heads: int

    def setup(self):
        self.self_attn = MultiHeadSelfAttention(embed_dim=self.embed_dim, num_heads=self.num_heads)
        self.ln = nn.LayerNorm()

    def __call__(self, x, mask=None):
        attn_output = self.self_attn(x, mask=mask)
        x = x + attn_output
        x = self.ln(x)
        return x

class TransformerDecoder(nn.Module):
    #vocab_size: int
    layer_dims: list
    num_heads: list

    def setup(self):
        #self.embedding = nn.Embed(num_embeddings=self.vocab_size, features=self.layer_dims[0]) # no embedding!
        self.layers = [TransformerDecoderLayer(embed_dim=layer_dim, num_heads=num_heads) 
                       for layer_dim, num_heads in zip(self.layer_dims, self.num_heads)]
        self.ln = nn.LayerNorm()

    def __call__(self, x, mask=None):
        #x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, mask=mask)
        x = self.ln(x)
        return x

class NextTokenPredictor(nn.Module):
    vocab_size: int
    layer_dims: list
    num_heads: list

    def setup(self):
        self.decoder = TransformerDecoder(
            #vocab_size=self.vocab_size,
            layer_dims=self.layer_dims,
            num_heads=self.num_heads
        )
        self.out = nn.Dense(features=self.vocab_size)

    def __call__(self, x):
        seq_length = x.shape[1]
        mask = causal_mask(seq_length)
        decoder_output = self.decoder(x, mask=mask)
        logits = self.out(decoder_output)
        return logits


In [20]:
## Model definition

S = 5  # Cardinality of the alphabet
T = 25  # Sequence length
m1 = 2  # Heads in the first layer
m2 = 1  # Heads in the second layer
num_heads = [m1, m2]  # Number of heads for each layer

d_0 = S + T # Dimension of the input sequence
d_1 = (1 + m1) * d_0  
d_2 = (1 + m2) * d_1  # embedding dimension  
# definition of d_l in the paper. 



vocab_size = S
layer_dims = [d_2, d_2]  

model = NextTokenPredictor(
    vocab_size=vocab_size,
    layer_dims=layer_dims,
    num_heads=num_heads
)


In [21]:
# Data Generation 
    
def create_3gram_transition_matrix(S):
    """
    Create a transition matrix for a 3-gram model by sampling from a Dirichlet prior.

    Returns:
    dict: A dictionary where keys are 2-grams (tuples of integers) and values are the probability distributions
          over the next words (integers from 0 to S-1).
    """
    # Create a list of all possible 2-grams from the vocabulary
    two_grams = [(i, j) for i in range(S) for j in range(S)]

    transition_matrix = {}     # Initialize the transition matrix as a dictionary
    alpha = 1.0     # Dirichlet parameter alpha
 
    for two_gram in two_grams:
        next_word_probs = np.random.dirichlet([alpha] * S)         # Sample a probability distribution over the next words
        transition_matrix[two_gram] = next_word_probs
    return transition_matrix


def generate_sequence(transition_matrix, T):
    """
    Generate a sequence of length T given a transition matrix.
    
    Parameters:
    transition_matrix (dict): The transition matrix where keys are 2-grams (tuples of integers)
                              and values are the probability distributions over the next words.
    T (int): Length of the sequence to generate.
    
    Returns:
    list: A list of integers representing the generated sequence.
    """
    two_grams = list(transition_matrix.keys())     # Extract the list of 2-grams from the transition matrix
    
    # Randomly choose an initial 2-gram
    current_2gram = two_grams[np.random.choice(len(two_grams))]
    
    sequence = list(current_2gram)     # Initialize the sequence with the chosen 2-gram
    
    # Generate the sequence
    for _ in range(T - 2):
        next_word_probs = transition_matrix[current_2gram]      # Get the next word probability distribution for the current 2-gram
        next_word = np.random.choice(range(len(next_word_probs)), p=next_word_probs)         # Sample the next word 
        sequence.append(next_word)
        # Update the current 2-gram
        current_2gram = (current_2gram[1], next_word)
    return sequence

Shape of loaded sequences: (25000, 25)
Shape of embedded sequences: (25000, 25, 180)


In [None]:
# embedding 
def embed(sequences):
    embedded_sequences = np.zeros((sequences.shape[0], sequences.shape[1], d_2), dtype=np.int32)
    for i in range(sequences.shape[0]):
        for j in range(sequences.shape[1]):
            embedded_sequences[i, j, sequences[i, j]] = 1
            embedded_sequences[i, j, S + j] = 1
    return embedded_sequences

In [22]:
## Training
import os

# Directory to save model parameters
save_dir = "saved_models"
os.makedirs(save_dir, exist_ok=True)

# Lists to store training loss and model parameters
training_losses = []
model_params_list = []

num_epochs = 2048
num_batches = 64
batch_size = 1024

# Define optimizer with cosine decay schedule
num_train_steps = num_epochs * num_batches
lr_schedule = optax.cosine_decay_schedule(0.3, num_train_steps)
optimizer = optax.chain(optax.adam(learning_rate=lr_schedule), optax.clip_by_global_norm(1.0))

# Initialize model and optimizer state
rng = random.PRNGKey(0)
params = model.init(rng, jnp.zeros((1, T, d_2), dtype=jnp.int32))

optimizer_state = optimizer.init(params)

# Training loop
for epoch in range(num_epochs):
    epoch_loss = 0.0
    transition_matrix = create_3gram_transition_matrix(S) # 3-gram transition fixed for the epoch
    
    for batch_idx in range(num_batches):
        sequences = np.array([generate_sequence(transition_matrix, T) for _ in range(batch_size)])


      # Generate target sequences by shifting the input sequences
        target_sequences = np.zeros_like(sequences)
        target_sequences[:, :-1] = sequences[:, 1:]
        
        # Convert target sequences to one-hot encoding
        sequences_onehot = np.zeros((sequences.shape[0], sequences.shape[1], S), dtype=np.int32)
        target_sequences_onehot = np.zeros((target_sequences.shape[0], target_sequences.shape[1], S), dtype=np.int32)
        for i in range(sequences.shape[0]):
            for j in range(sequences.shape[1]):
                sequences_onehot[i, j, sequences[i, j]] = 1
                target_sequences_onehot[i, j, target_sequences[i, j]] = 1
        
        # Embed input sequences
        embedded_sequences = embed(sequences)
        
        # Convert embedded sequences and targets to JAX arrays
        batch_sequences = jnp.array(embedded_sequences)
        batch_targets = jnp.array(target_sequences_onehot)

        
        # Compute gradients and loss
        def loss_fn(params):
            logits = model.apply(params, batch_sequences)
            loss = cross_entropy_loss(logits, batch_targets)
            return loss, logits
        
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, _), grads = grad_fn(params)
        epoch_loss += loss
        
        # Update parameters
        updates, optimizer_state = optimizer.update(grads, optimizer_state)
        params = optax.apply_updates(params, updates)
    
    # Calculate average epoch loss
    avg_epoch_loss = epoch_loss / num_batches
    training_losses.append(avg_epoch_loss)

    # Save model parameters
    if (epoch + 1) % 25 == 0:
        model_params_list.append(params)
        model_path = os.path.join(save_dir, f"model_epoch_{epoch+1}.params")
        with open(model_path, "wb") as f:
            f.write(flax.serialization.to_bytes(params))
    
    # Print epoch loss
    print(f"Epoch {epoch+1}, Loss: {avg_epoch_loss}")

# Save training losses to file
losses_path = os.path.join(save_dir, "training_losses.npy")
np.save(losses_path, np.array(training_losses))


Epoch 1, Loss: 51.807762145996094
Epoch 2, Loss: 41.703086853027344
Epoch 3, Loss: 40.71582794189453
Epoch 4, Loss: 40.68227767944336
Epoch 5, Loss: 44.09806823730469
Epoch 6, Loss: 41.85529327392578
Epoch 7, Loss: 42.92106628417969
Epoch 8, Loss: 42.15922546386719
Epoch 9, Loss: 42.38582229614258
