In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter, FuncFormatter

from scipy.special import erf

import torch
import torch.nn as nn
import torch.nn.functional as F

from time import time
from time import process_time

import math

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [3]:
# Set the key parametes
BATCH_SIZE = 4
embed_dim = 10
context_length = 8
vocab_size = 40

In [4]:
# Create a random dataset
data = torch.randint(low=0, high=vocab_size, size=(BATCH_SIZE, context_length))
data

tensor([[ 4, 10,  3,  0,  7, 15, 14,  3],
        [ 5, 29, 37, 22, 16, 30,  3, 20],
        [35, 18, 16, 10, 27, 33, 37,  2],
        [26, 18, 21, 20, 18, 33, 37, 24]])

In [5]:
# Create embeddings matrix
embeddings = nn.Embedding(vocab_size, embed_dim)

In [6]:
# Create Q, K, V matices
query = nn.Linear(embed_dim, embed_dim, bias=False)
key = nn.Linear(embed_dim, embed_dim, bias=False)
value = nn.Linear(embed_dim, embed_dim, bias=False)

### Process the data


In [7]:
X = embeddings(data)

In [8]:
# Create weights for the pre-tained model
q = query(X)
k = key(X)
v = value(X)

In [9]:
# Print data sizes
print(f'{'Data Matrix Shape':24}: {data.shape}')
print(f'{'Embeddings Matrix Shape':24}: {embeddings.weight.shape}')
print(f'{'Token matrix shape':24}: {X.shape}')

# sizes of matrices
print('')
print(f'{'Query Matrix Shape':24}: {query.weight.shape}')
print(f'{'Key Matrix Shape':24}: {key.weight.shape}')
print(f'{'Value Matrix Shape':24}: {value.weight.shape}')

# Attention matrix sizes
print('')
print(f'{'Q(x)':24}: {q.shape}')
print(f'{'K(x)':24}: {k.shape}')
print(f'{'V(x)':24}: {v.shape}')

Data Matrix Shape       : torch.Size([4, 8])
Embeddings Matrix Shape : torch.Size([40, 10])
Token matrix shape      : torch.Size([4, 8, 10])

Query Matrix Shape      : torch.Size([10, 10])
Key Matrix Shape        : torch.Size([10, 10])
Value Matrix Shape      : torch.Size([10, 10])

Q(x)                    : torch.Size([4, 8, 10])
K(x)                    : torch.Size([4, 8, 10])
V(x)                    : torch.Size([4, 8, 10])


# Create Attention activation
$$
\operatorname{Attention}(Q, K, V)
=
\operatorname{softmax}\!\left(
\frac{QK^{\top}}{\sqrt{d_k}} + M
\right) V
$$

Where,  
A = Attention function  
Q = Query Matrix  
K = Key Matrix  
$\\d_k$ = dimension of Key Matrix  
M = Masking
V = Value Matrix


## *Create attention manually*

In [10]:
# Create causal mask manually - objective is to set the future values in the context window to be zero after softmax
mask = torch.tril(torch.ones(BATCH_SIZE, context_length, context_length))

# Create the Q*K.T and scale it by square-root of embedding matrix
qk = q @ k.transpose(-2,-1)
qk_scaled = qk / (embed_dim ** 0.5)

# Mnaully set the non-diagonal elements in lower triagular matrix to -inf
qk_scaled[mask == 0] = -torch.inf

# Apply softmax
qk_softmaxed = F.softmax(qk_scaled, dim=-1)

# Multiply with value matrix
attention_manual = qk_softmaxed @ v

## *Create attention using PyTorch*

In [11]:
attention_torch = F.scaled_dot_product_attention(query=q,
                                           key=k,
                                           value=v,
                                           is_causal=True
)

torch.allclose(attention_manual, attention_torch, atol=1e-5, rtol=1e-5)

True

## *Calculate the processing times for both manual and pytorch methods on CPU*

In [12]:
num_runs = 50_000
start_time = process_time()

for _ in range(num_runs):
  mask = torch.tril(torch.ones(BATCH_SIZE, context_length, context_length))
  qk = q @ k.transpose(-2,-1)
  qk_scaled = qk / (embed_dim ** 0.5)
  qk_scaled[mask == 0] = -torch.inf
  qk_softmaxed = F.softmax(qk_scaled, dim=-1)
  attention_manual = qk_softmaxed @ v

elapsed_time = process_time() - start_time

print(f'Time taken for manual calculation: {elapsed_time:.3f} seconds')


Time taken for manual calculation: 43.629 seconds


In [13]:
%%timeit
for _ in range(num_runs):
  attention_torch = F.scaled_dot_product_attention(query=q,
                                           key=k,
                                           value=v,
                                           is_causal=True
)

6.57 s ± 59.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Run the entire code on GPU

In [14]:
# Set the key parametes
BATCH_SIZE = 64
embed_dim = 1000
context_length = 2048
vocab_size = 50_257

In [15]:
# Create a random dataset
data = torch.randint(low=0, high=vocab_size, size=(BATCH_SIZE, context_length), dtype=torch.long, device=device)

# Create embeddings matrix
embeddings = nn.Embedding(vocab_size, embed_dim, device=device)

#Create Q, K, V matices
query = nn.Linear(embed_dim, embed_dim, bias=False, device=device)
key = nn.Linear(embed_dim, embed_dim, bias=False, device=device)
value = nn.Linear(embed_dim, embed_dim, bias=False, device=device)


In [16]:
X = embeddings(data)

# Create weights for the pre-tained model
q = query(X)
k = key(X)
v = value(X)

In [17]:
num_runs = 200
start_time = process_time()

for _ in range(num_runs):
  mask = torch.tril(torch.ones(BATCH_SIZE, context_length, context_length))
  qk = q @ k.transpose(-2,-1)
  qk_scaled = qk / (embed_dim ** 0.5)
  qk_scaled[mask == 0] = -torch.inf
  qk_softmaxed = F.softmax(qk_scaled, dim=-1)
  attention_manual = qk_softmaxed @ v

elapsed_time = process_time() - start_time

print(f'Time taken for manual calculation on GPU: {elapsed_time:.3f} seconds')


Time taken for manual calculation on GPU: 634.034 seconds


In [18]:
start_time = process_time()

for _ in range(num_runs):
  attention_torch = F.scaled_dot_product_attention(query=q,
                                           key=k,
                                           value=v,
                                           is_causal=True
)

elapsed_time = process_time() - start_time

print(f'Time taken for PyTorch calculation on GPU: {elapsed_time:.3f} seconds')

Time taken for PyTorch calculation on GPU: 8.874 seconds


In [21]:
# Further optimizations on GPU
import torch._dynamo as dynamo
SDPA_compiled = torch.compile(F.scaled_dot_product_attention)

# Apply recommended TF32 precision settings
torch.backends.cuda.matmul.fp32_precision = "tf32"
torch.backends.cudnn.conv.fp32_precision = "tf32"

In [22]:
num_runs = 200

torch.cuda.empty_cache()
torch.cuda.synchronize()

# Manual Calculations
start_time = process_time()

for _ in range(num_runs):
  mask = torch.tril(torch.ones(BATCH_SIZE, context_length, context_length))
  qk = q @ k.transpose(-2,-1)
  qk_scaled = qk / (embed_dim ** 0.5)
  qk_scaled[mask == 0] = -torch.inf
  qk_softmaxed = F.softmax(qk_scaled, dim=-1)
  attention_manual = qk_softmaxed @ v

elapsed_time = process_time() - start_time

print(f'Time taken for manual calculation on GPU: {elapsed_time:.3f} seconds')

# Using PyTorch
start_time = process_time()

for _ in range(num_runs):
  attention_torch = F.scaled_dot_product_attention(query=q,
                                           key=k,
                                           value=v,
                                           is_causal=True
)

elapsed_time = process_time() - start_time

print(f'Time taken for PyTorch calculation on GPU: {elapsed_time:.3f} seconds')

# Using Optimizations
start_time = process_time()

for _ in range(num_runs):
  attention_torch = SDPA_compiled(query=q,
                                  key=k,
                                  value=v,
                                  is_causal=True
)

elapsed_time = process_time() - start_time

print(f'Optimized Time taken for PyTorch calculation on GPU: {elapsed_time:.3f} seconds')


Time taken for manual calculation on GPU: 633.002 seconds
Time taken for PyTorch calculation on GPU: 2.743 seconds
Optimized Time taken for PyTorch calculation on GPU: 1.263 seconds
