In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
from flax import struct
import optax
import numpy as np
from typing import Optional, Tuple, Any
import math
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import torch
from transformers import AutoTokenizer
import os
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration class for model parameters
from dataclasses import dataclass

@dataclass
class GPTConfig:
    """Hyperparameters for a GPT-style decoder (defaults are the GPT-2 124M sizes).

    d_model, num_heads, d_ff and dropout_rate are read as defaults by the
    Attention / MLP modules below. vocab_size, max_seq_len and num_layers are not
    used in this section and are presumably consumed by the full model later.

    Raises:
        ValueError: if any size is non-positive, if d_model is not divisible by
            num_heads, or if dropout_rate is outside [0, 1].
    """
    vocab_size: int = 50257
    max_seq_len: int = 1024
    d_model: int = 768        # feature width of Attention / MLP inputs and outputs
    num_layers: int = 12
    num_heads: int = 12       # attention heads; each gets d_model // num_heads features
    d_ff: int = 3072          # hidden width of the MLP block
    dropout_rate: float = 0.1

    def __post_init__(self):
        # Validate eagerly so a bad config fails here rather than deep inside a
        # reshape or matmul at model-init time.
        for name in ("vocab_size", "max_seq_len", "d_model", "num_layers", "num_heads", "d_ff"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        if self.d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})"
            )
        if not 0.0 <= self.dropout_rate <= 1.0:
            raise ValueError(f"dropout_rate must be in [0, 1], got {self.dropout_rate}")

config = GPTConfig()

In [13]:
class Attention(nn.Module):
    """Causal multi-head self-attention.

    Attributes:
        d_model: Feature width of the input projections and of the output;
            must be divisible by ``num_heads``.
        num_heads: Number of attention heads; each head attends over
            ``d_model // num_heads`` features.
        dropout_rate: Dropout probability applied to the output projection.
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    dropout_rate: float = config.dropout_rate

    def setup(self):
        if self.d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})"
            )
        self.head_size = self.d_model // self.num_heads
        # Project to all heads at once (d_model = num_heads * head_size);
        # the heads are split apart in __call__.
        self.d_Q = nn.Dense(features=self.d_model, use_bias=False)
        self.d_K = nn.Dense(features=self.d_model, use_bias=False)
        self.d_V = nn.Dense(features=self.d_model, use_bias=False)
        self.d_O = nn.Dense(features=self.d_model, use_bias=False)
        self.dropout = nn.Dropout(self.dropout_rate)

    def _split_heads(self, t):
        """Reshape (B, T, d_model) -> (B, num_heads, T, head_size)."""
        B, T, _ = t.shape
        return t.reshape(B, T, self.num_heads, self.head_size).transpose(0, 2, 1, 3)

    def __call__(self, x, training=True):
        """Apply causal self-attention.

        Args:
            x: Input of shape (B, T, features); T may be any length.
            training: When True, dropout is active and a 'dropout' PRNG stream
                should be supplied to ``apply``.

        Returns:
            Array of shape (B, T, d_model).
        """
        B, T, _ = x.shape
        query = self._split_heads(self.d_Q(x))
        key = self._split_heads(self.d_K(x))
        value = self._split_heads(self.d_V(x))

        # (B, H, T, T) scaled dot-product scores.
        weights = jnp.matmul(query, key.transpose(0, 1, 3, 2)) * (self.head_size ** -0.5)
        # Causal mask sized from the actual sequence length: position t may only
        # attend to positions <= t. Masked scores are set to the dtype's most
        # negative finite value (a literal -1e9 overflows in float16).
        causal = jnp.tril(jnp.ones((T, T), dtype=bool))
        weights = jnp.where(causal, weights, jnp.finfo(weights.dtype).min)
        weights = nn.softmax(weights, axis=-1)

        out = jnp.matmul(weights, value)                             # (B, H, T, head_size)
        out = out.transpose(0, 2, 1, 3).reshape(B, T, self.d_model)  # merge heads
        out = self.d_O(out)
        return self.dropout(out, deterministic=not training)

In [11]:
# Scratch inspection of a 128x128 additive causal mask: entries on/below the
# diagonal are 1.0, entries above it are -1e9.
# NOTE(review): additive masks conventionally use 0.0 for visible positions;
# 1.0 only works because every row adds the same constant to its visible
# scores and softmax is shift-invariant.
mask = jnp.where(jnp.tril(jnp.ones((128, 128))) != 0, 1.0, -1e9)

In [12]:
# Display the mask from the previous cell: 1.0 on/below the diagonal, -1e9 above it.
mask

Array([[ 1.e+00, -1.e+09, -1.e+09, ..., -1.e+09, -1.e+09, -1.e+09],
       [ 1.e+00,  1.e+00, -1.e+09, ..., -1.e+09, -1.e+09, -1.e+09],
       [ 1.e+00,  1.e+00,  1.e+00, ..., -1.e+09, -1.e+09, -1.e+09],
       ...,
       [ 1.e+00,  1.e+00,  1.e+00, ...,  1.e+00, -1.e+09, -1.e+09],
       [ 1.e+00,  1.e+00,  1.e+00, ...,  1.e+00,  1.e+00, -1.e+09],
       [ 1.e+00,  1.e+00,  1.e+00, ...,  1.e+00,  1.e+00,  1.e+00]],      dtype=float32, weak_type=True)

In [29]:


def test_attention():
    """Smoke-test Attention: output shape and causal masking.

    Raises:
        AssertionError: if the output shape differs from the input shape, or if
            outputs for early positions change when only later positions change.
    """
    attn = Attention()
    x = jnp.ones((2, 128, config.d_model))
    rng = jax.random.PRNGKey(0)
    params = attn.init(rng, x, training=True)
    out = attn.apply(params, x, training=True, rngs={'dropout': jax.random.PRNGKey(1)})

    # Check shape
    assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}"

    # Check causality: with dropout off, perturbing only the second half of the
    # sequence must leave the outputs for the first half unchanged.
    split = x.shape[1] // 2
    noise = jax.random.normal(jax.random.PRNGKey(2), x[:, split:].shape)
    x_perturbed = x.at[:, split:].add(noise)
    eval_out = attn.apply(params, x, training=False)
    eval_out_perturbed = attn.apply(params, x_perturbed, training=False)
    assert jnp.allclose(eval_out[:, :split], eval_out_perturbed[:, :split], atol=1e-5), \
        "Attention leaks information from future positions"

    print("Shape test passed!", out.shape)
    

In [30]:
# Run the Attention smoke test; an AssertionError here means a check failed.
test_attention()

Shape test passed! (2, 128, 768)


In [24]:
class MLP(nn.Module):
    """Position-wise feed-forward block: Dense(d_ff) -> GELU -> Dense(d_model) -> Dropout.

    Attributes:
        d_model: Output feature width.
        d_ff: Hidden feature width.
        dropout_rate: Dropout probability applied to the output.
    """
    d_model: int = config.d_model
    d_ff: int = config.d_ff
    dropout_rate: float = config.dropout_rate

    def setup(self):
        self.fc1 = nn.Dense(features=self.d_ff)
        self.fc2 = nn.Dense(features=self.d_model)
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x, training=True):
        """Apply the feed-forward block.

        Args:
            x: Input whose last axis is the feature axis, e.g. (B, T, d_model).
            training: When True, dropout is active and a 'dropout' PRNG stream
                should be supplied to ``apply``.

        Returns:
            Array with the same leading shape as ``x`` and last axis ``d_model``.
        """
        x = self.fc1(x)
        x = nn.gelu(x)
        # The nonlinearity belongs only on the hidden layer. A second GELU after
        # fc2 would clamp the block's output to roughly [-0.17, inf).
        x = self.fc2(x)
        return self.dropout(x, deterministic=not training)

In [27]:


def test_mlp():
    """Smoke-test MLP: a forward pass with dropout active must preserve the input shape."""
    model = MLP()
    inputs = jnp.ones((2, 128, config.d_model))
    variables = model.init(jax.random.PRNGKey(0), inputs, training=True)
    result = model.apply(
        variables,
        inputs,
        training=True,
        rngs={'dropout': jax.random.PRNGKey(1)},
    )

    assert result.shape == inputs.shape, f"Expected {inputs.shape}, got {result.shape}"
    print("Shape test passed!", result.shape)
    

In [28]:
# Run the MLP smoke test; an AssertionError here means the shape check failed.
test_mlp()

Shape test passed! (2, 128, 768)
