In [1]:
import os
import re
import urllib.request
from importlib.metadata import version
from dataclasses import dataclass

try:
    import tiktoken
except ImportError:
    !pip install tiktoken
    import tiktoken

try:
    import torch
except ImportError:
    !pip install torch
    import torch

print("torch version:", torch.__version__)
print("tiktoken version:", version("tiktoken"))

torch version: 2.1.2+cpu
tiktoken version: 0.5.2


In [2]:
from utils import attentions

In [3]:
# Seed before any random tensor is drawn (the embeddings are created when
# `Params()` is instantiated below), so the input batch is reproducible.
torch.manual_seed(123)


@dataclass
class Params:
    """Shared configuration and input batch for the attention benchmarks below.

    Every setting is type-annotated so ``@dataclass`` registers it as a real
    field; un-annotated class attributes are ignored by the decorator.

    Attributes:
        context_length: sequence length of the input batch.
        d_in: feature size of the input batch (passed as ``d_in`` to each module).
        d_out: total output feature size requested from each module.
        num_heads: number of attention heads; must divide ``d_out``.
        batch_size: number of sequences in the input batch.
        dropout: dropout probability forwarded to each module.
        qkv_bias: forwarded as ``qkv_bias`` to each module.
        device: CUDA if available, otherwise CPU.
        embeddings: random input batch of shape
            ``(batch_size, context_length, d_in)`` on ``device``, created per
            instance in ``__post_init__``.

    Raises:
        ValueError: if ``d_out`` is not divisible by ``num_heads``.
    """

    context_length: int = 1024
    d_in: int = 768
    d_out: int = 768
    num_heads: int = 12
    batch_size: int = 8
    dropout: float = 0.0
    qkv_bias: bool = False
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __post_init__(self) -> None:
        if self.d_out % self.num_heads != 0:
            raise ValueError(
                f"d_out ({self.d_out}) must be divisible by num_heads ({self.num_heads})"
            )
        # The modules below are built with `d_in` as their input size, so the
        # input batch's last dimension is d_in (not d_out). Assigned here rather
        # than declared as a field, so it stays out of __init__/__repr__/__eq__.
        self.embeddings = torch.randn(
            (self.batch_size, self.context_length, self.d_in), device=self.device
        )


my_params = Params()

print(f"PyTorch version: {torch.__version__}")
print(f"Running on {my_params.device}")

PyTorch version: 2.1.2+cpu
Running on cpu


In [4]:
# Sanity check: every attention module below is constructed with `d_in` as its
# input size, so the shared input batch must be (batch_size, context_length, d_in).
expected_shape = (my_params.batch_size, my_params.context_length, my_params.d_in)
assert my_params.embeddings.shape == expected_shape, (
    f"unexpected input shape {tuple(my_params.embeddings.shape)}, expected {expected_shape}"
)
my_params.embeddings.shape

torch.Size([8, 1024, 768])

## Multi-head attention built from several single-head attention modules (wrapper)

In [5]:
import time

In [6]:
# Each single-head module gets d_out // num_heads output features, so the heads'
# combined (presumably concatenated) output is d_out wide: 768 in the recorded output.
# Derived from num_heads rather than a hard-coded 12, so editing Params stays consistent.
mha_wrapper = attentions.MultiHeadAttentionWrapper(
    d_in=my_params.d_in,
    d_out=my_params.d_out // my_params.num_heads,
    context_length=my_params.context_length,
    dropout=my_params.dropout,
    num_heads=my_params.num_heads,
    qkv_bias=my_params.qkv_bias
).to(my_params.device)

# Time one forward pass.
# - no_grad() avoids recording an autograd graph during the timed pass.
# - perf_counter() is a monotonic high-resolution clock, unlike time.time().
# - On CUDA, kernels launch asynchronously, so synchronize before reading the clock.
# NOTE: a single un-warmed run is noisy; %timeit gives a more reliable comparison.
on_cuda = my_params.device.type == "cuda"
with torch.no_grad():
    if on_cuda:
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    out = mha_wrapper(my_params.embeddings)
    if on_cuda:
        torch.cuda.synchronize()
    t2 = time.perf_counter()
print(out.shape)
print(f"time cost >> {t2-t1:.4f}")

torch.Size([8, 1024, 768])
time cost >> 3.3361


## Multi-head attention implemented in a single module (no wrapper)

In [7]:
mha = attentions.MultiHeadAttention(
    d_in=my_params.d_in,
    d_out=my_params.d_out,
    context_length=my_params.context_length,
    dropout=my_params.dropout,
    num_heads=my_params.num_heads,
    qkv_bias=my_params.qkv_bias
).to(my_params.device)

# Time one forward pass.
# - no_grad() avoids recording an autograd graph during the timed pass.
# - perf_counter() is a monotonic high-resolution clock, unlike time.time().
# - On CUDA, kernels launch asynchronously, so synchronize before reading the clock.
# NOTE: a single un-warmed run is noisy; %timeit gives a more reliable comparison.
on_cuda = my_params.device.type == "cuda"
with torch.no_grad():
    if on_cuda:
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    out = mha(my_params.embeddings)
    if on_cuda:
        torch.cuda.synchronize()
    t2 = time.perf_counter()
print(out.shape)
print(f"time cost >> {t2-t1:.4f}")

torch.Size([8, 1024, 768])
time cost >> 3.3689


## Multi-head attention with a single combined QKV weight matrix

In [8]:
mha_combined_qkv = attentions.MultiHeadAttentionCombinedQKV(
    d_in=my_params.d_in,
    d_out=my_params.d_out,
    context_length=my_params.context_length,
    dropout=my_params.dropout,
    num_heads=my_params.num_heads,
    qkv_bias=my_params.qkv_bias
).to(my_params.device)

# Time one forward pass.
# - no_grad() avoids recording an autograd graph during the timed pass.
# - perf_counter() is a monotonic high-resolution clock, unlike time.time().
# - On CUDA, kernels launch asynchronously, so synchronize before reading the clock.
# NOTE: a single un-warmed run is noisy; %timeit gives a more reliable comparison.
on_cuda = my_params.device.type == "cuda"
with torch.no_grad():
    if on_cuda:
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    out = mha_combined_qkv(my_params.embeddings)
    if on_cuda:
        torch.cuda.synchronize()
    t2 = time.perf_counter()
print(out.shape)
print(f"time cost >> {t2-t1:.4f}")

torch.Size([8, 1024, 768])
time cost >> 4.8357


## Multi-head attention using PyTorch's scaled dot-product attention

In [9]:
mha_pytorch_scaled = attentions.MHAPyTorchScaledDotProduct(
    d_in=my_params.d_in,
    d_out=my_params.d_out,
    context_length=my_params.context_length,
    dropout=my_params.dropout,
    num_heads=my_params.num_heads,
    qkv_bias=my_params.qkv_bias
).to(my_params.device)

# Time one forward pass.
# - no_grad() avoids recording an autograd graph during the timed pass.
# - perf_counter() is a monotonic high-resolution clock, unlike time.time().
# - On CUDA, kernels launch asynchronously, so synchronize before reading the clock.
# NOTE: a single un-warmed run is noisy; %timeit gives a more reliable comparison.
on_cuda = my_params.device.type == "cuda"
with torch.no_grad():
    if on_cuda:
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    out = mha_pytorch_scaled(my_params.embeddings)
    if on_cuda:
        torch.cuda.synchronize()
    t2 = time.perf_counter()
print(out.shape)
print(f"time cost >> {t2-t1:.4f}")

torch.Size([8, 1024, 768])
time cost >> 2.2110


## Multi-head attention using `torch.nn.MultiheadAttention` (`need_weights=True`)

In [10]:
mha_pytorch_class_default = attentions.MHAPyTorchClass(
    d_in=my_params.d_in,
    d_out=my_params.d_out,
    context_length=my_params.context_length,
    dropout=my_params.dropout,
    num_heads=my_params.num_heads,
    qkv_bias=my_params.qkv_bias,
    need_weights=True
).to(my_params.device)

# Time one forward pass.
# - no_grad() avoids recording an autograd graph during the timed pass.
# - perf_counter() is a monotonic high-resolution clock, unlike time.time().
# - On CUDA, kernels launch asynchronously, so synchronize before reading the clock.
# NOTE: a single un-warmed run is noisy; %timeit gives a more reliable comparison.
on_cuda = my_params.device.type == "cuda"
with torch.no_grad():
    if on_cuda:
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    out = mha_pytorch_class_default(my_params.embeddings)
    if on_cuda:
        torch.cuda.synchronize()
    t2 = time.perf_counter()
print(out.shape)
print(f"time cost >> {t2-t1:.4f}")

torch.Size([8, 1024, 768])
time cost >> 2.9372


## Using PyTorch's `torch.nn.MultiheadAttention` with `scaled_dot_product_attention` (`need_weights=False`)
- `need_weights`: if specified, returns `attn_output_weights` in addition to `attn_output`. Set `need_weights=False` to use the optimized `scaled_dot_product_attention` and achieve the best performance for MHA. Default: `True`.

In [11]:
mha_pytorch_class_noweights = attentions.MHAPyTorchClass(
    d_in=my_params.d_in,
    d_out=my_params.d_out,
    context_length=my_params.context_length,
    dropout=my_params.dropout,
    num_heads=my_params.num_heads,
    qkv_bias=my_params.qkv_bias,
    need_weights=False  # only difference from the previous cell
).to(my_params.device)

# Time one forward pass.
# - no_grad() avoids recording an autograd graph during the timed pass.
# - perf_counter() is a monotonic high-resolution clock, unlike time.time().
# - On CUDA, kernels launch asynchronously, so synchronize before reading the clock.
# NOTE: a single un-warmed run is noisy; %timeit gives a more reliable comparison.
on_cuda = my_params.device.type == "cuda"
with torch.no_grad():
    if on_cuda:
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    out = mha_pytorch_class_noweights(my_params.embeddings)
    if on_cuda:
        torch.cuda.synchronize()
    t2 = time.perf_counter()
print(out.shape)
print(f"time cost >> {t2-t1:.4f}")

torch.Size([8, 1024, 768])
time cost >> 2.8121
