In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
from flax import struct
import optax
import numpy as np
from typing import Optional, Tuple, Any
import math
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import torch
from transformers import AutoTokenizer
import os
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [58]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 ships without a padding token. The literal '[PAD]' is not in its
# vocabulary, so assigning it only resolves (silently) to the unknown-token id.
# Alias the end-of-text token explicitly instead: the pad id stays 50256, but
# the choice is now visible and the pad token is a real vocabulary entry.
tokenizer.pad_token = tokenizer.eos_token

In [60]:
tokenizer.pad_token_id

50256

In [131]:
# Configuration class for model parameters
from dataclasses import dataclass

@dataclass
class GPTConfig:
    """Hyperparameters for the GPT model, its optimizer schedule and the training loop.

    Defaults are evaluated once, when the class is defined; in particular
    ``vocab_size`` is read from the module-level ``tokenizer`` at that moment.
    """
    vocab_size: int = tokenizer.vocab_size  # size of the embedding table and of the output logits
    max_seq_len: int = 1024  # length of the learned positional table and of the dummy init input
    d_model: int = 768  # width of the residual stream
    num_layers: int = 12  # number of stacked transformer blocks
    num_heads: int = 12  # attention heads per block; head size is d_model // num_heads
    d_ff: int = 3072  # hidden width of the MLP
    dropout_rate: float = 0.1  # rate handed to every nn.Dropout
    lr: float = 6e-4  # peak learning rate, scaled by the warmup/cosine schedule
    warmup_steps: int = 700  # steps of linear warmup
    total_steps: int = 20000  # step at which the cosine decay reaches its end
    batch_size: int = 64  # DataLoader batch size
    gradient_accumulation_steps: int = 1  # NOTE(review): not read by any code visible in this notebook
    mixed_precision: bool = False  # NOTE(review): not read by any code visible in this notebook
    num_epochs: int = 1  # passes over the training loader
    
config = GPTConfig()

In [106]:
class Attention(nn.Module):
    """A single causal self-attention head.

    Projects the input to query/key/value vectors of size
    ``d_model // num_heads``, applies scaled dot-product attention in which
    each position may only attend to itself and earlier positions, and
    projects the result back to ``d_model``.

    Attributes:
        d_model: width of the input and of the returned tensor.
        num_heads: number of heads in the enclosing layer; only used to
            derive the per-head size.
        dropout_rate: dropout rate applied to the head output.
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    dropout_rate: float = config.dropout_rate

    def setup(self):
        self.head_size = self.d_model // self.num_heads
        self.d_Q = nn.Dense(features=self.head_size, use_bias=False)
        self.d_K = nn.Dense(features=self.head_size, use_bias=False)
        self.d_V = nn.Dense(features=self.head_size, use_bias=False)
        self.d_O = nn.Dense(features=self.d_model, use_bias=False)
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x, training=True):
        """Apply causal attention to ``x`` of shape (batch, seq, channels).

        Args:
            x: rank-3 input tensor.
            training: when True dropout is active and a 'dropout' RNG is needed.

        Returns:
            Tensor of shape (batch, seq, d_model).
        """
        _, T, _ = x.shape
        query = self.d_Q(x)
        key = self.d_K(x)
        value = self.d_V(x)

        scores = jnp.matmul(query, key.transpose(0, 2, 1)) * (key.shape[-1] ** -0.5)

        # Causal masking must REPLACE the scores of future positions with a
        # very negative value. Multiplying them by -1e9 instead flips the sign
        # of negative scores into huge positive ones, which makes the softmax
        # attend almost exclusively to the future.
        causal = jnp.tril(jnp.ones((T, T), dtype=bool))
        # finfo.min rather than -inf keeps the softmax finite in any float dtype.
        scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)

        weights = nn.softmax(scores, axis=-1)
        out = jnp.matmul(weights, value)
        out = self.d_O(out)
        out = self.dropout(out, deterministic=not training)
        return out

In [107]:
# Scratch cell: build a 128x128 lower-triangular pattern for visual inspection.
# Entries on or below the diagonal are 1.0; entries above it become -1e9.
# NOTE(review): this array is not usable as a multiplicative attention mask -
# a negative score times -1e9 turns into a large positive value.
mask = jnp.tril(jnp.ones((128,128)))
mask = jnp.where(mask==0, -1e9, 1.0)

In [108]:
mask

Array([[ 1.e+00, -1.e+09, -1.e+09, ..., -1.e+09, -1.e+09, -1.e+09],
       [ 1.e+00,  1.e+00, -1.e+09, ..., -1.e+09, -1.e+09, -1.e+09],
       [ 1.e+00,  1.e+00,  1.e+00, ..., -1.e+09, -1.e+09, -1.e+09],
       ...,
       [ 1.e+00,  1.e+00,  1.e+00, ...,  1.e+00, -1.e+09, -1.e+09],
       [ 1.e+00,  1.e+00,  1.e+00, ...,  1.e+00,  1.e+00, -1.e+09],
       [ 1.e+00,  1.e+00,  1.e+00, ...,  1.e+00,  1.e+00,  1.e+00]],      dtype=float32, weak_type=True)

In [109]:
class MHA(nn.Module):
    """Multi-head attention built from independent ``Attention`` heads.

    Every head is run on the same input, the head outputs are concatenated
    along the last axis, and a dense layer maps the result to ``d_model``
    before dropout is applied.
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    dropout_rate: float = config.dropout_rate

    def setup(self):
        head_kwargs = dict(
            d_model=self.d_model,
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
        )
        self.heads = [Attention(**head_kwargs) for _ in range(self.num_heads)]
        self.linear = nn.Dense(features=self.d_model)
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x, training=True):
        # Run each head in turn and collect its output.
        per_head = []
        for head in self.heads:
            per_head.append(head(x, training))

        merged = jnp.concatenate(per_head, axis=-1)
        projected = self.linear(merged)
        return self.dropout(projected, deterministic=not training)

In [110]:


def test_attention():
    """Smoke test: the MHA layer must return a tensor shaped like its input."""
    layer = MHA()
    inputs = jnp.ones((2, 128, config.d_model))

    init_key = jax.random.PRNGKey(0)
    dropout_key = jax.random.PRNGKey(1)

    variables = layer.init(init_key, inputs, training=True)
    result = layer.apply(variables, inputs, training=True, rngs={'dropout': dropout_key})

    # Attention mixes information across positions but keeps the shape.
    assert result.shape == inputs.shape, f"Expected {inputs.shape}, got {result.shape}"
    print("Shape test passed!", result.shape)
    

In [111]:
test_attention()

Shape test passed! (2, 128, 768)


In [112]:
class MLP(nn.Module):
    """Position-wise feed-forward network: Dense -> GELU -> Dense -> Dropout.

    Attributes:
        d_model: width of the input and of the returned tensor.
        d_ff: width of the hidden layer.
        dropout_rate: dropout rate applied to the output.
    """
    d_model: int = config.d_model
    d_ff: int = config.d_ff
    dropout_rate: float = config.dropout_rate

    def setup(self):
        self.fc1 = nn.Dense(features=self.d_ff)
        self.fc2 = nn.Dense(features=self.d_model)
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x, training=True):
        """Transform ``x``; dropout is active only when ``training`` is True."""
        x = self.fc1(x)
        x = nn.gelu(x)
        # No activation after the output projection: this value is added to the
        # residual stream, and a trailing GELU would bound it below (about
        # -0.17), so the block could barely push any feature downwards.
        x = self.fc2(x)
        x = self.dropout(x, deterministic=not training)
        return x

In [113]:


def test_mlp():
    """Smoke test: the MLP must return a tensor shaped like its input."""
    layer = MLP()
    inputs = jnp.ones((2, 128, config.d_model))

    init_key = jax.random.PRNGKey(0)
    dropout_key = jax.random.PRNGKey(1)

    variables = layer.init(init_key, inputs, training=True)
    result = layer.apply(variables, inputs, training=True, rngs={'dropout': dropout_key})

    # The feed-forward network acts per position and keeps the shape.
    assert result.shape == inputs.shape, f"Expected {inputs.shape}, got {result.shape}"
    print("Shape test passed!", result.shape)
    

In [114]:
test_mlp()

Shape test passed! (2, 128, 768)


In [115]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Each sub-layer (attention, then MLP) reads a layer-normalised copy of the
    stream and its output is added back onto the un-normalised stream.
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    d_ff: int = config.d_ff
    dropout_rate: float = config.dropout_rate

    def setup(self):
        self.attention = MHA(
            d_model=self.d_model,
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
        )
        self.mlp = MLP(
            d_model=self.d_model,
            d_ff=self.d_ff,
            dropout_rate=self.dropout_rate,
        )
        self.ln1 = nn.LayerNorm()
        self.ln2 = nn.LayerNorm()

    def __call__(self, x, training=True):
        # Attention sub-layer with residual connection.
        attn_update = self.attention(self.ln1(x), training)
        x = x + attn_update

        # Feed-forward sub-layer with residual connection.
        mlp_update = self.mlp(self.ln2(x), training)
        return x + mlp_update

In [116]:
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token ids are embedded, summed with learned positional embeddings, passed
    through a stack of ``TransformerBlock`` layers and projected to vocabulary
    logits.

    Attributes:
        d_model: width of the residual stream.
        num_heads: attention heads per block.
        d_ff: hidden width of each block's MLP.
        dropout_rate: dropout rate used throughout the model.
        vocab_size: number of token embeddings and output logits.
        seq_len: length of the positional table, i.e. the longest supported input.
        num_layers: number of transformer blocks.
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    d_ff: int = config.d_ff
    dropout_rate: float = config.dropout_rate
    vocab_size: int = config.vocab_size
    seq_len: int = config.max_seq_len
    # Appended last so existing positional constructions keep their meaning.
    num_layers: int = config.num_layers

    def setup(self):
        self.embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.d_model)
        self.positional_embedding = self.param(
            "positional_embeddings",  # name
            lambda key: jax.random.normal(key, (1, self.seq_len, self.d_model)) * 0.02
        )
        self.decoder = [
            TransformerBlock(self.d_model, self.num_heads, self.d_ff, self.dropout_rate)
            for _ in range(self.num_layers)
        ]
        self.linear_out = nn.Dense(features=self.vocab_size)
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x, training=True):
        """Compute next-token logits.

        Args:
            x: integer token ids of shape (batch, seq).
            training: when True dropout is active and a 'dropout' RNG is needed.

        Returns:
            Logits of shape (batch, seq, vocab_size).

        Raises:
            ValueError: if the sequence is longer than ``seq_len``.
        """
        _, T = x.shape
        if T > self.seq_len:
            raise ValueError(
                f"Sequence length {T} exceeds the positional table size {self.seq_len}"
            )

        embeds = self.embedding_table(x)  # (B,T,d_model)
        pos_embeds = self.positional_embedding[:, :T, :]  # (1,T,d_model)
        h = embeds + pos_embeds  # (B,T,d_model)

        # The former "pad mask" compared these float embeddings against a token
        # id, so it was always all-ones and has been removed. Padding has to be
        # handled where token ids are available (e.g. when averaging the loss).

        # Regularise the embeddings rather than the logits: zeroing random
        # logits distorts the softmax the cross-entropy is computed from.
        h = self.dropout(h, deterministic=not training)

        for layer in self.decoder:
            h = layer(h, training=training)

        return self.linear_out(h)

In [117]:
def test_gpt():
    """Smoke test: GPT must map (batch, seq) token ids to (batch, seq, vocab) logits."""
    net = GPT()
    token_ids = jnp.ones((2, 128), dtype=jnp.int32)
    expected_shape = (2, 128, config.vocab_size)

    variables = net.init(jax.random.PRNGKey(0), token_ids, training=True)
    apply_rngs = {
        'dropout': jax.random.PRNGKey(1),
        "positional_embeddings": jax.random.PRNGKey(2),
    }
    logits = net.apply(variables, token_ids, training=True, rngs=apply_rngs)

    # One score per vocabulary entry at every position.
    assert logits.shape == expected_shape, f"Expected {expected_shape}, got {logits.shape}"
    print("Shape test passed!", logits.shape)

In [118]:
test_gpt()

Shape test passed! (2, 128, 50257)


In [119]:
def create_learning_rate_schedule(config=None):
    """Create a learning rate schedule with linear warmup and cosine decay.

    Args:
        config: object exposing ``lr``, ``warmup_steps`` and ``total_steps``.
            When omitted a fresh ``GPTConfig()`` is used.

    Returns:
        A function mapping a step (Python int or JAX integer array) to the
        learning rate for that step. The rate rises linearly from 0 to ``lr``
        over ``warmup_steps``, follows half a cosine down to 0 at
        ``total_steps`` and stays at 0 afterwards.
    """
    if config is None:
        config = GPTConfig()

    # Guard against zero-length phases so the divisions below are well defined.
    warmup_steps = max(1, config.warmup_steps)
    decay_steps = max(1, config.total_steps - config.warmup_steps)

    def schedule(step):
        # Linear warmup
        warmup_ratio = jnp.minimum(1.0, step / warmup_steps)
        # Cosine decay after warmup. The ratio is clamped at 1: without the
        # upper bound the cosine turns around after total_steps and the
        # learning rate climbs back towards its peak.
        decay_ratio = jnp.clip((step - config.warmup_steps) / decay_steps, 0.0, 1.0)
        cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * decay_ratio))
        return config.lr * warmup_ratio * cosine_decay

    return schedule

In [120]:
def compute_ce_loss(logits, labels, pad_token_id: Optional[int] = None):
    """Compute the next-token cross-entropy loss.

    The logits at position ``t`` are scored against the token at ``t + 1``,
    so the last logit and the first label are dropped.

    Args:
        logits: array of shape (batch, seq, vocab).
        labels: integer token ids of shape (batch, seq).
        pad_token_id: when given, targets equal to this id are excluded from
            the average. When None every position counts.

    Returns:
        Scalar mean loss over the positions that count.
    """
    targets = labels[:, 1:]
    shifted_logits = logits[:, :-1, :]  # Shift logits to align with labels

    per_token = optax.softmax_cross_entropy_with_integer_labels(shifted_logits, targets)
    if pad_token_id is None:
        return per_token.mean()

    keep = (targets != pad_token_id).astype(per_token.dtype)
    # Clamp the denominator so an all-padding batch yields 0 rather than NaN.
    return (per_token * keep).sum() / jnp.maximum(keep.sum(), 1.0)

In [124]:
def create_train_state(rng, config):
    """Create the initial training state (model parameters + AdamW optimizer).

    Args:
        rng: PRNG key used to initialise the model parameters.
        config: object providing the model hyperparameters (``d_model``,
            ``num_heads``, ``d_ff``, ``dropout_rate``, ``vocab_size``,
            ``max_seq_len``).

    Returns:
        A ``TrainState`` holding ``model.apply``, the parameters and the optimizer.
    """
    # Build the model from the config that was passed in, so a non-default
    # config is honoured. The number of layers is left to GPT's own default.
    model = GPT(
        d_model=config.d_model,
        num_heads=config.num_heads,
        d_ff=config.d_ff,
        dropout_rate=config.dropout_rate,
        vocab_size=config.vocab_size,
        seq_len=config.max_seq_len,
    )

    # Initialize parameters. Dropout is switched off for the shape-inference
    # pass so that initialisation only needs the parameter RNG.
    dummy_input = jnp.ones((1, config.max_seq_len), dtype=jnp.int32)
    params = model.init(rng, dummy_input, training=False)['params']

    # Create learning rate schedule
    lr_schedule = create_learning_rate_schedule()

    # Create optimizer
    tx = optax.adamw(
        learning_rate=lr_schedule,
        b1=0.9,
        b2=0.95,
        weight_decay=0.1
    )

    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx
    )

In [123]:
state = create_train_state(jax.random.PRNGKey(0), config)
# Printing the TrainState itself dumps every parameter and optimizer array
# into the cell output; a parameter count is enough to confirm it was built.
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(state.params))
print(f"TrainState initialised with {num_params:,} parameters")

In [125]:
@jax.jit
def train_step(state, batch, dropout_rng=None):
    """Run one optimisation step.

    Args:
        state: current ``TrainState``.
        batch: integer token ids of shape (batch, seq); used both as model
            input and as the source of next-token targets.
        dropout_rng: PRNG key for dropout. When omitted, a key is derived from
            ``state.step`` so that every step uses a different dropout mask.

    Returns:
        ``(new_state, loss)`` where ``loss`` is the scalar training loss.
    """
    if dropout_rng is None:
        dropout_rng = jax.random.fold_in(jax.random.PRNGKey(0), state.step)

    def loss_fn(params):
        # training=True enables dropout, which needs its own RNG stream;
        # calling apply without one raises an error.
        logits = state.apply_fn(
            {"params": params},
            batch,
            training=True,
            rngs={"dropout": dropout_rng},
        )
        return compute_ce_loss(logits, batch)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)

    # Update the parameters
    state = state.apply_gradients(grads=grads)
    return state, loss
    

In [None]:
# Evaluation step, compiled with jax.jit
@jax.jit
def eval_step(state, batch):
    """Return the cross-entropy loss of ``state`` on ``batch`` with dropout disabled."""
    variables = {'params': state.params}
    logits = state.apply_fn(variables, batch, training=False)
    return compute_ce_loss(logits, batch)

In [130]:
# Prediction step, compiled with jax.jit
@jax.jit
def predict_batch(state, batch):
    """Run the model on ``batch`` in inference mode and return its raw output."""
    variables = {'params': state.params}
    outputs = state.apply_fn(variables, batch, training=False)
    return outputs


In [None]:
def train():
    """Train the model for ``config.num_epochs`` epochs and report losses.

    Relies on names defined outside this function: ``train_dataset``,
    ``val_dataset``, ``collate_fn`` and ``generate_text`` (none of which is
    visible in this part of the notebook), plus ``config``, ``tokenizer``,
    ``create_train_state``, ``create_learning_rate_schedule``, ``train_step``
    and ``eval_step``.
    """
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=4
    )

    rng = jax.random.PRNGKey(0)
    # Named `state` (not `train_state`) so the imported flax module of that
    # name is not shadowed.
    state = create_train_state(rng, config)
    # Kept only for display: the optimizer built by optax.adamw does not expose
    # a `hyperparams` entry, so the current rate is recomputed from the step.
    lr_schedule = create_learning_rate_schedule()

    num_epochs = config.num_epochs
    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        # Training. The state is carried over from epoch to epoch; re-creating
        # it from the initial state here would throw away all progress.
        train_losses = []

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
        for batch in pbar:
            # Convert to JAX array
            batch = jnp.array(batch)

            current_lr = float(lr_schedule(state.step))
            state, loss = train_step(state, batch)
            loss = float(loss)
            train_losses.append(loss)

            # Update progress bar
            pbar.set_postfix({
                "loss": f"{loss:.4f}",
                "lr": f"{current_lr:.6f}"
            })

        # Validation
        val_losses = []

        pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")
        for batch in pbar:
            batch = jnp.array(batch)
            # eval_step takes (state, batch) and returns the loss only.
            loss = float(eval_step(state, batch))
            val_losses.append(loss)

            # Update progress bar
            pbar.set_postfix({"loss": f"{loss:.4f}"})

        # Calculate epoch metrics (NaN when a loader produced no batches)
        avg_train_loss = float(np.mean(train_losses)) if train_losses else float("nan")
        avg_val_loss = float(np.mean(val_losses)) if val_losses else float("nan")

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss: {avg_val_loss:.4f}")

        # Track best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # TODO: no checkpoint is actually written yet; only the message is printed.
            print("  New best model! Saving checkpoint...")

    # Generate some text
    print("\nGenerating text...")
    generated = generate_text(state, "The future of artificial intelligence", tokenizer, max_length=50)
    print(f"Generated: {generated}")
    
    