|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 2:</h2>|<h1>Large language models</h1>|
|<h2>Section:</h2>|<h1>Build a GPT</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge HELPER: Code Attention manually and in PyTorch</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

import time

# Exercise 1: Simulate data and attention matrices

In [None]:
# parameters: deliberately tiny so every printed tensor stays readable
n_batch        = 4   # batch dimension
n_embed        = 10  # embedding dimension
context_length = 8   # tokens per sequence
vocab_size     = 40  # vocabulary size

# input data
# placeholder for the exercise: complete the torch.randint(...) call so that `data`
# holds integer token ids (presumably in [0, vocab_size)) of shape [batch,tokens]
data = torch.randint  # [batch,tokens]

In [None]:
# embedding matrix
# TODO (exercise): complete the layer -- presumably a learnable lookup table with
# vocab_size rows and n_embed columns; confirm against the lecture
embeddings = nn.

# create the q,k,v matrices
# TODO (exercise): complete the three projection layers. The partial `,bias=False)` left on
# `value` suggests bias-free layers; presumably all three map n_embed -> n_embed (the timing
# cells below scale q@k^T by n_embed**-.5) -- confirm
key   = nn.
query = nn.
value = nn.,bias=False)

### Process the data

In [None]:
# tokens to embeddings
# TODO (exercise): apply the embedding layer to `data`; the result should presumably be
# [batch, tokens, n_embed]
x = embeddings

# weight the data pre-attention
# TODO (exercise): project `x` through the key/query/value layers created above
k = key
q =
v =

In [None]:
# print data sizes
# TODO (exercise): finish each (intentionally unterminated) f-string with the matching
# object's shape; the leading spaces right-align the three labels
print(f'      Data matrix:
print(f'Embeddings matrix:
print(f' Token embeddings:

# sizes of matrices
# TODO (exercise): after the blank separator line printed below, print the weight shapes
# of the key/query/value layers (presumably what "matrices" refers to here)
print('')


# print attention matrices sizes
# TODO (exercise): after the blank separator line printed below, print the shapes of k, q and v
print('')


# Exercise 2: Implement self-attention

In [None]:
### manual implementation

# "cosine similarity" between query and keys (note: it would only be a true cosine
# similarity if each dot product were divided by |q||k|)
# TODO (exercise): matrix-multiply q with k, swapping k's last two dims (compare the timing cells below)
qk =  # transpose non-batch dimensions

# variance-scale the QK
# TODO (exercise): the timing cells below scale by n_embed**-.5
qk_scaled =

# apply mask for future tokens
# TODO (exercise): finish the lower-triangular ones matrix; the timing cells below build it
# with shape [n_batch, context_length, context_length]. Entries where the mask is 0 lie above
# the diagonal (a query position looking at later tokens) and are set to -inf, so softmax
# gives them zero weight
pastmask = torch.tril(torch.ones
qk_scaled[pastmask==0] =  # equivalent to adding a matrix of zeros/-infs

# softmaxify
# TODO (exercise): normalize over the last (key) dimension -- dim=-1, as in the timing cells
qk_softmax = F.softmax(

# and final attention mechanism
# TODO (exercise): weight the values by the attention probabilities
actsManual =

print(f'Shape of activations (manual): {actsManual.shape}') # [batch, context, n_embed]

In [None]:
# pytorch implementation
# TODO (exercise): compute causal attention with PyTorch's built-in function
# (F.scaled_dot_product_attention with causal masking enabled)
actsTorch =
print(f'Shape of activations (PyTorch): {actsTorch.shape}')

In [None]:
# compare
# first batch element from each implementation, followed by their element-wise difference
print(actsManual[0,:,:])
print('')
print(actsTorch[0,:,:])
print('')
print(actsManual[0,:,:]-actsTorch[0,:,:])

# TODO (exercise): complete both checks. torch.equal demands bit-for-bit identical tensors;
# the second check should tolerate tiny floating-point differences (e.g. torch.allclose),
# since the two implementations need not round identically
print(f'\n\nAre they _exactly_ equal? {torch.equal
print(f'Are they "equal"? {torch.

# Exercise 3: CPU computation time

In [None]:
# CPU timing: the manual attention (reference) vs. the optimized PyTorch call
numReps = 50_000

# the manual version
start_time = time.time()
for _ in range(numReps):
  qk = q@k.transpose(-2,-1) * (n_embed**-.5)
  # the mask is rebuilt on every repetition (on the CPU: no device argument),
  # so its cost is included in the manual timing
  pastmask = torch.tril(torch.ones(n_batch,context_length,context_length))
  qk[pastmask==0] = -torch.inf
  qk = F.softmax(qk,dim=-1)
  activations = qk @ v
print(f'---    Manual: {time.time()-start_time:.3f} sec')

# the optimized version
# TODO (exercise): reset start_time, then loop numReps times over the PyTorch implementation.
# Without the reset, the print below reports manual + optimized time combined.

print(f'--- Optimized: {time.time()-start_time:.3f} sec')

# Exercise 4: GPU computation time

In [None]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')

### Using bigger matrices

In [None]:
# parameters
# larger sizes for the speed test; 50257 appears to be the GPT-2 tokenizer's vocabulary size.
# Memory note: one float32 [n_batch, context_length, context_length] tensor (e.g. qk or the
# mask in the timing cells) takes 64*2048*2048*4 bytes = 1 GiB
n_batch = 64
n_embed = 1000
context_length = 2048
vocab_size = 50257

# create matrices
# TODO (exercise): repeat the Exercise 1 recipe with these sizes; the tensors presumably need
# to live on `device`, since the timing cells build the mask with device=device
# NOTE(review): Exercise 1 calls the embedding layer `embeddings`; this cell uses `embedding`
data =
embedding =
key   =
query =
value =

x =
k =
q =
v =

### Now for the test!

In [None]:
# GPU timing: manual attention vs. PyTorch's built-in attention, at the larger sizes
numReps = 200

torch.cuda.synchronize() # synchronize the GPU&CPU. good for time-testing, bad for overall performance
start_time = time.time()
for _ in range(numReps):
  qk = q@k.transpose(-2,-1) * (n_embed**-.5)
  # TODO (exercise): finish the mask, the masking of qk, the softmax and the activations
  # (a complete version appears in a later cell; create the mask with device=device)
  pastmask
  qk
  qk = # softmax
  activationsM = # activations
# NOTE(review): CUDA work is queued asynchronously, so reading the clock without another
# torch.cuda.synchronize() mostly measures kernel-launch time. torch.cuda.synchronize()
# is also expected to raise when `device` fell back to the CPU.
print(f'--- Manual:  {time.time()-start_time:.3f} sec')


torch.cuda.synchronize()
start_time = time.time()
for _ in range(numReps):
  # TODO (exercise): the PyTorch implementation goes here

print(f'--- Pytorch: {time.time()-start_time:.3f} sec')

In [None]:
# some additional optimizations
import torch._dynamo
SDPA_compiled = torch.compile(F.scaled_dot_product_attention)
torch.set_float32_matmul_precision('high')

In [None]:
# FYI, FlashAttention: https://github.com/Dao-AILab/flash-attention

In [None]:
# Benchmark: manual causal attention vs. F.scaled_dot_product_attention (eager and compiled).
# CUDA kernels run asynchronously, so the device is synchronized both before starting and
# before stopping every timer; otherwise the numbers mostly reflect kernel-launch overhead.
numReps = 200

def _sync():
  """Block until all queued GPU work has finished; no-op when `device` is the CPU."""
  if device.type == 'cuda':
    torch.cuda.synchronize(device)

# manual version: scaled scores -> causal mask -> softmax -> weighted sum of values
_sync()
start_time = time.perf_counter()
for _ in range(numReps):
  qk = q@k.transpose(-2,-1) * (n_embed**-.5)
  # the mask is rebuilt on every repetition, so its cost is part of the manual timing
  pastmask = torch.tril(torch.ones(n_batch,context_length,context_length,device=device))
  qk[pastmask==0] = -torch.inf
  qk = F.softmax(qk,dim=-1)
  activationsM = qk @ v
_sync()
print(f'--- Manual:  {time.perf_counter()-start_time:.3f} sec')


# PyTorch's fused kernel; is_causal=True applies the same lower-triangular mask
_sync()
start_time = time.perf_counter()
for _ in range(numReps):
  activationsP = F.scaled_dot_product_attention(q,k,v,is_causal=True)
_sync()
print(f'--- Pytorch: {time.perf_counter()-start_time:.3f} sec')


# torch.compile'd version: the first call triggers compilation, so make one warm-up call
# outside the timed region and measure only steady-state speed
activationsO = SDPA_compiled(q,k,v,is_causal=True)
_sync()
start_time = time.perf_counter()
for _ in range(numReps):
  activationsO = SDPA_compiled(q,k,v,is_causal=True)
_sync()
print(f'--- Compiled: {time.perf_counter()-start_time:.3f} sec')