In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
from flax import struct
import optax
import numpy as np
from typing import Optional, Tuple, Any
import math
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import torch
from transformers import AutoTokenizer
import os
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration class for model parameters
from dataclasses import dataclass

@dataclass
class GPTConfig:
    """Bundle of model hyperparameters with overridable defaults.

    Every field can be overridden by keyword (or positionally, in declaration
    order). Field meanings are taken from their names; in the visible part of
    the notebook only ``d_model``, ``num_heads`` and ``dropout_rate`` are
    consumed (as defaults of the ``Attention`` module).
    """

    # Token / sequence dimensions.
    vocab_size: int = 50_257
    max_seq_len: int = 1_024

    # Network width and depth.
    d_model: int = 768
    num_layers: int = 12
    num_heads: int = 12
    d_ff: int = 3_072

    # Regularisation.
    dropout_rate: float = 0.1


# Shared default instance; later cells read their default field values from it.
config = GPTConfig()

In [3]:
class Attention(nn.Module):
    """Single causal self-attention head followed by an output projection.

    The query/key/value projections each have width
    ``head_size = d_model // num_heads``; ``num_heads`` is only used to derive
    that width -- this module computes one head, not ``num_heads`` of them.
    The head output is projected back to ``d_model`` and passed through dropout.

    Attributes:
        d_model: Output width of the final projection.
        num_heads: Divisor used to derive ``head_size`` from ``d_model``.
        dropout_rate: Dropout probability applied to the projected output.
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    dropout_rate: float = config.dropout_rate

    def setup(self):
        """Create the bias-free Q/K/V/output projections and the dropout layer."""
        self.head_size = self.d_model // self.num_heads
        self.d_Q = nn.Dense(features=self.head_size, use_bias=False)
        self.d_K = nn.Dense(features=self.head_size, use_bias=False)
        self.d_V = nn.Dense(features=self.head_size, use_bias=False)
        self.d_O = nn.Dense(features=self.d_model, use_bias=False)
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x, training=True):
        """Apply causal scaled dot-product attention to ``x``.

        Args:
            x: Array of shape ``(..., seq_len, features)``; the second-to-last
                axis is treated as the sequence axis.
            training: When true, dropout is active (a ``'dropout'`` RNG must be
                supplied to ``apply``); when false, dropout is a no-op.

        Returns:
            Array of shape ``(..., seq_len, d_model)``.
        """
        seq_len = x.shape[-2]
        query = self.d_Q(x)
        key = self.d_K(x)
        value = self.d_V(x)

        # Scaled dot-product scores; swapaxes transposes only the last two axes,
        # so any number of leading batch dimensions is accepted.
        scores = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) * (key.shape[-1] ** -0.5)

        # Lower-triangular mask: position i may attend to positions j <= i.
        # The shape must be passed as a tuple -- a second positional argument
        # to jnp.ones is interpreted as the dtype.
        causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        # Suppress the *disallowed* (future) positions. The diagonal is always
        # allowed, so every row keeps at least one finite score.
        scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)

        weights = nn.softmax(scores, axis=-1)
        out = jnp.matmul(weights, value)
        out = self.d_O(out)
        out = self.dropout(out, deterministic=not training)
        return out