In [51]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from time import process_time

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

# Hyperparameters

In [3]:
n_vocab = 50257 # GPT-2 vocabulary size
embed_dim = 768 # GPT-2 embedding domension size
seq_len = 1024 # GPT-2 sequence length/ context window size
n_heads = 12 # GPT-2 number of heads
n_layers = 12 # GPT-2 number of layers/ number of transfomer blocks
batch_size = 8

# Create a Multihead attention class

In [11]:
class MultiHeadAttention(nn.Module):
  def __init__(self):
    super().__init__()
    self.n_heads = n_heads
    self.head_dim = embed_dim // n_heads

    self.QKV = nn.Linear(embed_dim, 3 * embed_dim, bias=True)

    # Linear mixing after attention
    self.W0 = nn.Linear(embed_dim, embed_dim, bias=True)

  def forward(self, x, track_sizes=False):

    B, T, E = x.shape # Batch, seq_length, embedding dimension

    qkv = self.QKV(x)
    q, k, v = torch.split(qkv, embed_dim, dim=-1) # split into separate matrices
    if track_sizes: print(f"1){' QKV shape:':>28} {qkv.shape}")
    if track_sizes: print(f"1a){' q shape:':>28} {q.shape}")

    # Reshape the q, k, v matrices into [Batch, n_heads, seq_length (T), head_dim] --> SDPA needs the tensor of shape [B, n_heads, T, head_dim]
    q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
    k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
    v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
    if track_sizes: print(f"2){' Q shape after reshape:':>28} {q.shape}")

    # SDPA and transpose the matrix again
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True).transpose(1,2).reshape(B, T, E)
    if track_sizes: print(f"3){' Attention output shape:':>28} {out.shape}")

    # Pass it through the Linear mixing matrix
    out = self.W0(out)
    if track_sizes: print(f"4){' Final output shape:':>28} {out.shape}")

    return out

In [12]:
mha = MultiHeadAttention()
out = mha(torch.randn(batch_size, seq_len, embed_dim),track_sizes=True)

1)                  QKV shape: torch.Size([8, 1024, 2304])
1a)                    q shape: torch.Size([8, 1024, 768])
2)      Q shape after reshape: torch.Size([8, 12, 1024, 64])
3)     Attention output shape: torch.Size([8, 1024, 768])
4)         Final output shape: torch.Size([8, 1024, 768])


# Create the Tranformer Block

In [35]:
class TransformerBlock(nn.Module):
  def __init__(self):
    super().__init__()
    self.mha = MultiHeadAttention()
    self.layernorm1 = nn.LayerNorm(embed_dim)
    self.layernorm2 = nn.LayerNorm(embed_dim)
    self.ff = nn.Sequential(
        nn.Linear(embed_dim, 4 * embed_dim),
        nn.GELU(),
        nn.Linear(4 * embed_dim, embed_dim)
    )

  def forward(self, x):

    # Attention
    out_attn = x + self.mha(self.layernorm1(x)) # pre-attention normalization + attention

    # MLP layer
    out = out_attn + self.ff(self.layernorm2(out_attn)) # post-attention normalization + MLP

    return out


In [36]:
transform = TransformerBlock()
out = transform(torch.randn(batch_size, seq_len, embed_dim))

out.shape

torch.Size([8, 1024, 768])

# Create the entire model

In [44]:
class LLModel(nn.Module):
  def __init__(self, device):
    super().__init__()

    # token + position embeddings
    self.wte = nn.Embedding(n_vocab, embed_dim) # token embedding
    self.wpe = nn.Embedding(seq_len, embed_dim) # position embedding

    # transformer blocks
    self.blocks = nn.Sequential(*[TransformerBlock() for _ in range(n_layers)])

    # final layernorm
    self.layernorm = nn.LayerNorm(embed_dim)

    # final head
    self.head = nn.Linear(embed_dim, n_vocab)
    self.head.weight = nn.Parameter(self.wte.weight) # share weights

    self.device = device

  def forward(self, x):
    token_emb = self.wte(x) # [B,T,E]
    position_embed = self.wpe(torch.arange(x.shape[-1], device=self.device)) #[T,E]
    out = token_emb + position_embed #[T,E]

    # Pass through transformer blocks
    out = self.blocks(out)

    # Pass through final layernorm
    out = self.layernorm(out)

    # Pass through final head/ MLP
    out = self.head(out) # [B,T,n_vocab]
    return out

  def generate(self, tokx, temperature=1., max_new_tokens=50):
    for _ in range(max_new_tokens):

      # forward pass
      logits = self(tokx[:, -seq_len:]) # [B,n_vocab]

      # Apply temeprature and softmax
      probs = F.softmax(logits / temperature, dim=-1) # [B, n_vocab]

      # Next sample token
      tokx_next = torch.multinomial(probs, num_samples=1)

      # Add to the sequence
      tokx = torch.cat((tokx, tokx_next), dim=1)
    return tokx

In [46]:
model = LLModel(device)
model


LLModel(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (blocks): Sequential(
    (0): TransformerBlock(
      (mha): MultiHeadAttention(
        (QKV): Linear(in_features=768, out_features=2304, bias=True)
        (W0): Linear(in_features=768, out_features=768, bias=True)
      )
      (layernorm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
    (1): TransformerBlock(
      (mha): MultiHeadAttention(
        (QKV): Linear(in_features=768, out_features=2304, bias=True)
        (W0): Linear(in_features=768, out_features=768, bias=True)
      )
      (layernorm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((768,), eps=1e-05, elementwise_affin

In [47]:
data = torch.randint(0, n_vocab, (batch_size, seq_len)).to(device)
data.shape

torch.Size([8, 1024])

In [48]:
model.to(device)
out = model(data)
out.shape

torch.Size([8, 1024, 50257])

In [49]:
!pip install torchinfo
from torchinfo import summary



In [50]:
summary(model, input_size=(batch_size, seq_len), dtypes=[torch.long], col_names=['input_size','output_size','num_params','trainable'], row_settings=['var_names'])

Layer (type (var_name))                  Input Shape               Output Shape              Param #                   Trainable
LLModel (LLModel)                        [8, 1024]                 [8, 1024, 50257]          --                        True
├─Embedding (wte)                        [8, 1024]                 [8, 1024, 768]            38,597,376                True
├─Embedding (wpe)                        [1024]                    [1024, 768]               786,432                   True
├─Sequential (blocks)                    [8, 1024, 768]            [8, 1024, 768]            --                        True
│    └─TransformerBlock (0)              [8, 1024, 768]            [8, 1024, 768]            --                        True
│    │    └─LayerNorm (layernorm1)       [8, 1024, 768]            [8, 1024, 768]            1,536                     True
│    │    └─MultiHeadAttention (mha)     [8, 1024, 768]            [8, 1024, 768]            2,362,368                 True
│  

In [56]:
# Instantiate once on CPU and then on GPU and run forward pass 5 times on a sample data

num_runs = 5
data = torch.randint(0, n_vocab, (batch_size, seq_len))

# Run on CPU
start_time = process_time()
for _ in range(num_runs):
  model = LLModel('cpu')
  out = model(data)
print(f'Elapsed time CPU: {process_time()-start_time:,.3f} sec')

# Run on GPU
start_time = process_time()
for _ in range(num_runs):
  model = LLModel(device).to(device)
  data = data.to(device)
  out = model(data)
print(f'Elapsed time GPU: {process_time()-start_time:,.3f} sec')

Elapsed time CPU: 126.069 sec
Elapsed time GPU: 5.911 sec


In [68]:
def loss_optim(model, data, device):
  # Define loss function and optimizer
  loss_func = nn.NLLLoss().to(device)
  optimizer = torch.optim.AdamW(model.parameters(),lr=0.001)
  data = data.to(device)
  model = LLModel(device=device).to(device=device)

  # Forward pass
  out = model(data)

  # Calculate loss
  loss = loss_func(out.view(-1, n_vocab), data.view(-1))

  # back-propagation to compute the gradients
  model.zero_grad() # Initialize the gradients
  loss.backward() # back-propagation of loss function to compute the gradients over the loss function
  optimizer.step() # Re-compute the trainable weight parameters

  return loss

In [72]:
# model = LLModel('cpu')
data = torch.randint(0, n_vocab, (batch_size, seq_len))

%timeit loss_optim(model, data,'cpu')
%timeit loss_optim(model, data, device)

12.8 s ± 96.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.49 s ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
