In [78]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import mlflow as mlflow

# Hyper Parameters
batch_size = 64
block_size = 128
n_embd = 32
n_head = 4
nvocab = 65
bias = True
dropout = 0.2
import math

In [54]:
c = nn.Linear(n_embed,3*n_embed)

In [23]:
embedding_table = nn.Embedding(nvocab,n_embed)
positional_embedding = nn.Embedding(block_size,n_embed)

In [37]:
d = embedding_table(torch.randint(65,(batch_size,block_size)))+positional_embedding(torch.arange(block_size))

In [38]:
d.shape

torch.Size([64, 128, 32])

In [20]:
torch.randint(1000 - block_size ,(batch_size,)).shape

torch.Size([64])

In [53]:
torch.matmul(c,d)

RuntimeError: The size of tensor a (128) must match the size of tensor b (64) at non-singleton dimension 1

In [None]:

c d.shape

(torch.Size([64, 128, 4, 8]), torch.Size([64, 128, 32]))

In [57]:
e = c(d)

In [71]:
k,q,v = e.split(32,dim=-1)

In [70]:
k.view(64,4,128,8)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [82]:
k.view(64,128,4,8).transpose(1,2).shape

torch.Size([64, 4, 128, 8])

In [79]:
class CausalSelfAttention(nn.Module):

    def __init__(self):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd, bias=bias)
        # regularization
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size))
                                        .view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        print(k.shape,q.shape,v.shape)
        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


In [80]:
ca = CausalSelfAttention()

In [81]:
ca(d)

torch.Size([64, 4, 128, 8]) torch.Size([64, 4, 128, 8]) torch.Size([64, 4, 128, 8])


tensor([[[-0.0000e+00,  8.4630e-01,  9.1909e-01,  ..., -1.1952e+00,
           1.0077e+00,  2.9027e-01],
         [-7.6402e-01,  7.9218e-01,  0.0000e+00,  ..., -7.4257e-01,
           1.6362e-01,  0.0000e+00],
         [ 7.0273e-02,  2.1619e-01, -3.8167e-01,  ..., -4.2901e-01,
          -0.0000e+00, -1.9627e-01],
         ...,
         [ 3.3033e-01, -4.8015e-02,  2.9802e-01,  ..., -1.9262e-01,
           7.8003e-02,  1.6547e-02],
         [ 2.6780e-01, -1.0062e-01,  1.7785e-01,  ..., -1.8446e-01,
           1.4190e-01, -1.6145e-02],
         [ 3.4581e-02,  0.0000e+00,  2.5374e-01,  ..., -0.0000e+00,
           2.3404e-01,  1.4306e-02]],

        [[-1.2588e-01,  1.2317e-03,  6.9245e-01,  ..., -2.9219e-02,
           1.4338e-01,  3.1266e-01],
         [-2.0354e-01,  2.3665e-01,  7.1682e-01,  ..., -1.0134e-01,
           6.6756e-02, -3.0114e-01],
         [-2.7937e-02,  1.2498e-01, -4.2154e-02,  ..., -1.0832e-01,
           1.6853e-01,  5.1448e-01],
         ...,
         [ 0.0000e+00,  6