Using (B, ) matrix shape convention.

In [96]:
import numpy as np
import math

In [114]:
class LinearTransformation:
    """
    A learned linear transformation y = x * W^T + b

    in_size is the length of the vector input x.
    out_size is the length of the vector output y.
    """
    def __init__(self, in_size: int, out_size: int):
        self.in_size = in_size
        self.out_size = out_size

        # Cached variables.
        self.x = None
        self.W = np.random.randn(self.in_size, self.out_size)
        self.b = np.random.randn(self.out_size)
        self.dW = np.zeros_like(self.W)
        self.db = np.zeros_like(self.b)

    def __call__(self, x):
        # Ensure x is flat:
        assert len(x.shape) <=2, f"x must have shape (B, dim1),\
                        not {x.shape}"
        self.x = x
        self.y = np.matmul(self.x, self.W) + self.b
        return self.y
    
    def backward(self, d_out):
        # Find ∂L/∂W, ∂L/∂b and ∂L/∂x
        # Find derivative w.r.t W.
        self.dW = np.matmul(self.x.T, d_out)
        self.db = np.sum(d_out, axis=0)
        dx = np.matmul(d_out, self.W.T)
        return dx
    
    def update(self, lr):
        self.W = self.W - self.dW * lr
        self.b = self.b - self.db * lr

In [115]:
# Unit test: LinearTransformation.linear_transformation(x).


In [142]:
def softmax(x): #https://stackoverflow.com/questions/34968722/how-to-implement-the-softmax-function-in-python
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / (e_x.sum(axis=0) + 1e-5) # For numerical stability.

In [146]:
class MultiHeadAttention:
    def __init__(self, embed_dim:int, num_heads=8):
        assert embed_dim % num_heads == 0, \
            "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        # 8 heads in 'Attention is all you need'.
        self.num_heads = num_heads 

        # Linear transformations.
        self.linearQ = LinearTransformation(in_size=self.embed_dim,
                                            out_size=self.embed_dim)
        self.linearK = LinearTransformation(in_size=self.embed_dim,
                                            out_size=self.embed_dim)
        self.linearV = LinearTransformation(in_size=self.embed_dim,
                                            out_size=self.embed_dim)
        self.linearS = LinearTransformation(in_size=embed_dim \
                                            * num_heads,
                                            out_size=embed_dim)
        
        # Cache for backpropagation.
        self.Q = None
        self.K = None
        self.V = None
        self.Qh = None
        self.Kh = None
        self.Vh = None
        self.scores = None
        self.attention_weights = None
        self.context_per_head = None
        self.context_merge = None
        self.output = None

        # Weights.
        self.dQ = None
        self.dK = None
        self.dV = None


        
    def __call__(self, x):
        # Step 1: Calculate LTs for Q, K, V.
        self.Q = self.linearQ(x)
        self.K = self.linearK(x)
        self.V = self.linearV(x)

        # Step 2: break dim -1 into dim -2 and -1 by 
        # dim -2 = dim -1//num_heads.
        # i.e from (batch, dim-2, dim-1) into (batch, dim-3, dim-2, dim-1).
        Q_reshape = self.Q.reshape(
            self.Q.shape[0], 
            self.Q.shape[1],
            self.num_heads,
            self.Q.shape[2] // self.num_heads
        )
        K_reshape = self.K.reshape(
            self.K.shape[0], 
            self.K.shape[1],
            self.num_heads,
            self.K.shape[2] // self.num_heads
        )
        V_reshape = self.V.reshape(
            self.V.shape[0], 
            self.V.shape[1],
            self.num_heads,
            self.V.shape[2] // self.num_heads
        )

        # Step 3: permute the dimensions from:
        # (batch, dim1, num_heads, dim3) to 
        # (batch, num_heads, dim1, dim3).
        self.Qh = Q_reshape.transpose(0, 2, 1, 3)
        self.Kh = K_reshape.transpose(0, 2, 1, 3)
        self.Vh = V_reshape.transpose(0, 2, 1, 3) 

        # Step 4: Scaled Dot Product Attention.
        # Step 4.1: permute K dims.
        K_perm = np.swapaxes(self.Kh, -2, -1)

        # Step 4.2: matrix multiply Q and K_transpose.
        QK = np.matmul(self.Qh, K_perm)

        # Step 4.3: scale prior product with the number of elements in K.
        self.scores = QK / math.sqrt(self.Qh.shape[-1] * self.Qh.shape[-2])

        # Step 4.4: mask (optional).

        # Step 4.5: softmax.
        self.attention_weights = softmax(self.scores, axis=-1)

        # Step 4.6: matrix multiply prior product and V.
        self.context_per_head = np.matmul(self.attention_weights, self.Vh)
        
        # Step 5: permute the num_heads back to dim 2.
        context_transpose = self.context_per_head\
                                .transpose(0, 2, 1, 3)

        # Step 6: merge dimensions -1 and -2.
        self.context_merge = context_transpose.reshape(
            context_transpose.shape[0],
            context_transpose.shape[1],
            int(context_transpose.shape[2] *\
                     context_transpose.shape[3])
        )
        
        # Step 7: linear transformation
        self.output = self.linearS(self.context_merge)
        return self.output
        

    def backward(self, d_out):
        # Work out gradients ∂L/∂Q, ∂L/∂K, ∂L/∂V and dx.
        B, T, E = self.x.shape

        # Backprop through final linear projection.
        d_concat = d_out @ self.W_o.T # (B, T, E)
        self.dW_o = self.combine_heads(self.attn_output).reshape(B * T, E).T @ d_out.reshape(B * T, E)

        # Split heads
        d_attn_output = self.split_heads(d_concat)           # (B, H, T, D)

        # d_attn_weights @ V => d_attn_weights, d_V_heads
        d_attn_weights = d_attn_output @ self.V_heads.transpose(0,1,3,2)
        d_V_heads = self.attn_weights.transpose(0,1,3,2) @ d_attn_output

        # Softmax backward
        # dS = softmax * (d - sum(softmax * d))
        softmax = self.attn_weights
        d_scores = d_attn_weights * softmax - softmax * np.sum(d_attn_weights * softmax, axis=-1, keepdims=True)

        d_scores /= np.sqrt(self.head_dim)

        # Q @ K^T => dQ_heads, dK_heads
        d_Q_heads = d_scores @ self.K_heads
        d_K_heads = d_scores.transpose(0,1,3,2) @ self.Q_heads

        # Merge heads
        d_Q = self.combine_heads(d_Q_heads)                  # (B, T, E)
        d_K = self.combine_heads(d_K_heads)
        d_V = self.combine_heads(d_V_heads)

        # Backprop through input projections
        self.dW_q = self.x.reshape(B*T, E).T @ d_Q.reshape(B*T, E)
        self.dW_k = self.x.reshape(B*T, E).T @ d_K.reshape(B*T, E)
        self.dW_v = self.x.reshape(B*T, E).T @ d_V.reshape(B*T, E)

        # Chain rule to previous layer
        dx_q = d_Q @ self.W_q.T
        dx_k = d_K @ self.W_k.T
        dx_v = d_V @ self.W_v.T
        dx = dx_q + dx_k + dx_v

        return dx

    def update(self):
        self.linearQ.update(lr=)
        pass

In [147]:
# Unit test: MultiHeadAttention.calculate_attention(Q, K, V).
UTQ, UTK, UTV = np.random.rand(1, 8, 8), np.random.rand(1, 8, 8), np.random.rand(1, 8, 8)
UT_calc_att = MultiHeadAttention(embed_dim=512)
#UT_att = UT_calc_att.calculate_attention(UTQ, UTK, UTV)
#print(UT_att)

In [148]:
class Tokenizer:
    def __init__(self, vocab_length, embed_dim):
        self.vocab_length = vocab_length
        self.embed_dim = embed_dim
        self.E = np.random.randn(vocab_length, embed_dim)

    def positional_encoding(self, seq_length: int, embed_dim: int):
        """
        Generates a positional encoding for a given length and depth.

        Args:
        - length: the length of the input sequence.
        - embed_dim: the dimensionality of the encoding.

        Returns:
        - np.array, Positional encoding of shape (length, embed_dim).
        
        """
        embed_dim /= 2

        positions = np.arange(seq_length)[:, np.newaxis]
        depths = np.arange(embed_dim)[np.newaxis, :] / embed_dim

        angle_rates = 1 / (10000**depths)
        angle_rads = positions * angle_rates

        pos_encoding = np.concatenate([
                                    np.sin(angle_rads), 
                                    np.cos(angle_rads)], 
                                    axis=-1)
        
        return pos_encoding
    
    def embed_tokens(self, tokens, seq_length, embed_dim):
        embed = self.E[tokens]
        return embed + self.positional_encoding(seq_length, embed_dim)



In [149]:
class LayerNorm:
    def __init__(self, embed_dim, eps=1e-5):
        self.embed_dim = embed_dim
        self.eps = eps # Numerical stability.

        # Initialize scale (gamma) and shift (beta).
        self.gamma = np.ones((self.embed_dim, )) 
        self.beta = np.zeros((self.embed_dim, )) 

        # Initialize gradients.
        self.dgamma = np.zeros_like(self.gamma)
        self.dbeta = np.zeros_like(self.beta)

        # Cache.
        self.x = None
        self.mean = None
        self.var = None
        self.std = None
        self.x_norm = None
        
    def __call__(self, x):
        self.x = x      # Save for backpropagation.
        self.mean = np.mean(x, axis=-1, keepdims=True)
        self.var = np.var(x, axis=-1, keepdims=True)
        self.std = np.sqrt(self.var + self.eps)

        self.x_norm = (x - self.mean) / self.std

        return self.gamma[None, None, :] * self.x_norm \
                + self.beta[None, None, :]
    
    def backward(self, d_out):
        # Compute ∂L/∂x, ∂L/∂gamma and ∂L/∂beta.
        pass



In [150]:
def gelu(x):
    # From PyTorch documenation.
    return 0.5 * x * (1 + np.tanh(\
        np.sqrt(2 / np.pi) \
        * x + 0.044715 * x ** 3))

In [151]:
class GPT2:
    """
    Defines the overarching process, including tokenization.
    """
    def __init__(self, vocab_size, embed_dim, max_seq_len, batch, paper_dim=3072):
        # Define all of the components which make up the transformer.
        # Embedding for words.
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.embedding = Tokenizer(self.vocab_size, 
                                   self.embed_dim)
        
        self.max_seq_len = max_seq_len  # As each sample is padded.
        self.batch = batch
        self.dims = (self.batch, max_seq_len, self.embed_dim)

        # Layer Norm.
        self.layernorm1 = LayerNorm(embed_dim=self.embed_dim)
        self.linear1 = LinearTransformation(in_size=self.embed_dim,
                                            out_size=self.embed_dim)

        # Multihead attention, Q, K, V are incorporated into this.
        self.q_proj = LinearTransformation(in_size=self.embed_dim,
                                           out_size=self.embed_dim)
        self.k_proj = LinearTransformation(in_size=self.embed_dim,
                                           out_size=self.embed_dim)
        self.v_proj = LinearTransformation(in_size=self.embed_dim,
                                           out_size=self.embed_dim)
        self.multihead_attention = MultiHeadAttention\
                                    (embed_dim=self.embed_dim)
        # Produces a matrix (Batch, max_seq_len, embed_dim).
        
        self.layernorm2 = LayerNorm(embed_dim=self.embed_dim)

        # Feed forward is two linear transformation layers
        # with gelu activation.
        self.paper_dim = paper_dim # Richer context dimension <- 3072.
        self.linearff1 = LinearTransformation(in_size=self.embed_dim,
                                              out_size=self.paper_dim)
        self.linearff2 = LinearTransformation(in_size=self.paper_dim,
                                              out_size=self.embed_dim)

        self.layernorm3 = LayerNorm(embed_dim=self.embed_dim)
        self.linear3 = LinearTransformation(in_size=self.embed_dim,
                                            out_size=self.embed_dim)


    def __call__(self, batch):
        # Embed the entire input
        # tokenized_input = self.tokenizer.\
        #                     embed_tokens(tokens= ,
        #                                 seq_length= ,
        #                                 embed_dim=)
        
        # Feed the tokens into block1 and save for residual connection.
        tokenized_input = batch
        ln1_out = self.layernorm1(tokenized_input)
        lin1_out = self.linear1(ln1_out)

        # Multihead attention.
        Q = self.q_proj(lin1_out)
        K = self.k_proj(lin1_out) 
        V = self.v_proj(lin1_out)
        mha_out = self.multihead_attention(Q, K, V)

        # Residual connection 1.
        residual = mha_out + tokenized_input

        # Layer Norm and FFN.
        ln2_out = self.layernorm2(residual)
        ff1_out = self.linearff1(ln2_out)
        ff2_out = self.linearff2(ff1_out)
        
        # Concatenate the residual with ff_out.
        concat = ff2_out + residual

        # Final block and return softmax for probabilties.
        ln3_out = self.layernorm3(concat)
        lin3_out = self.linear3(ln3_out)
        logits = softmax(lin3_out)

        return logits
            

    def backward(self):
        # Do the calculation for each loss for each layer, backwards.
        

    
    

In [152]:
# Unit test: GPT2.__call__(batch).
# Random input.
B, seq_len, embed_dim = 4, 10, 768
x = np.random.randint(1, 768, (B, seq_len, embed_dim))
model = GPT2(vocab_size=10000,
             embed_dim=embed_dim,
             max_seq_len=seq_len,
             batch=B)
output = model(x)
print(output)

[[[1.91349995e-27 5.81111069e-13 5.73926770e-22 ... 1.19570838e-27
   9.58005228e-15 2.66700377e-16]
  [5.44747373e-40 3.21771446e-28 7.98375206e-30 ... 2.06267159e-17
   1.62393800e-18 1.11615097e-24]
  [1.31087789e-16 2.83409528e-23 3.67691288e-12 ... 4.87758839e-18
   1.76014385e-37 9.28408266e-35]
  ...
  [5.30493913e-21 2.27255554e-26 7.29809773e-07 ... 1.41928656e-29
   9.82121906e-27 1.20477375e-30]
  [3.65983447e-27 7.59923740e-26 1.17513713e-19 ... 1.12300108e-19
   1.98346432e-16 1.58426387e-35]
  [3.72829596e-32 8.09315423e-08 1.74512267e-32 ... 1.80276927e-23
   1.94020637e-25 1.18963248e-22]]

 [[6.17288596e-29 9.24793160e-16 2.39275596e-25 ... 8.33868400e-37
   7.25870891e-24 1.11443872e-12]
  [6.47713845e-14 2.26396379e-19 3.83992818e-23 ... 1.17170028e-19
   1.07900847e-26 1.01314615e-18]
  [2.87589966e-17 1.00618494e-28 3.21250249e-06 ... 6.69301236e-24
   3.37230500e-29 2.17756533e-30]
  ...
  [3.00140612e-21 7.17045790e-23 3.87637550e-25 ... 7.09166978e-17
   1.13065

In [153]:
# Load data in.

In [154]:
# Use PyTorch Loss function to save time.
from torch.nn.functional import cross_entropy
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    def __init__(self, ):
        pass

    def __len__(self):
        return None

    def __getitem__(self, i):
        pass

In [155]:
model = GPT2(
    vocab_size= ,
    embed_dim= ,
    max_seq_len= ,
    batch= ,
)

def train(model, data, hparams):
    model.train()

    for epoch in range(1):
        for i, batch in enumerate(dataloader):
            



    

SyntaxError: expected argument value expression (3494379259.py, line 2)