<a href="https://colab.research.google.com/github/QasimWani/simple-transformer/blob/main/transformers/tensor_puzzle.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 10 Exercises to improve your working knowledge of einsum and transformers
# Easy - come up with an answer in 1-2 minutes
# Medium - come up with an answer in 3 minutes
# Hard - come up with an answer in 4 minutes

import torch
from einops import rearrange
import numpy as np

In [2]:
# Simple profiler
import functools
import time

def profile(func):
    """Decorator that prints how long each call to `func` takes.

    Elapsed wall-clock time is measured with ``time.perf_counter`` and
    printed as ``"<name> took <seconds> seconds"`` (six decimals). The
    wrapped function's return value is passed through unchanged; if it
    raises, nothing is printed.
    """
    @functools.wraps(func)
    def timed(*args, **kwargs):
        t0 = time.perf_counter()
        value = func(*args, **kwargs)
        elapsed = time.perf_counter() - t0
        print(f"{func.__name__} took {elapsed:.6f} seconds")
        return value

    return timed

In [3]:
# Problem #1 - Sequence dot product (Easy).
# Description: Calculate the dot product between two 1-d vectors

@profile
def seq_dot_spec(a, b, out):
    """Loop reference: write sum_i a[i] * b[i] into out[0]."""
    out[0] = 0
    # Walk `a` element by element and accumulate the matching product from `b`.
    for idx, a_val in enumerate(a):
        out[0] += a_val * b[idx]

@profile
def seq_dot(a, b):
    """Return the dot product of two 1-D tensors as a 0-d tensor.

    In 'i,i->' both operands share index i, so their elements are multiplied
    pairwise; because i does not appear in the output, it is summed over:
    sum_i a[i] * b[i].

    By contrast, 'i,j->' would form the full outer product and sum all of it,
    which equals a.sum() * b.sum() -- not a dot product.
    """
    return torch.einsum('i,i->', a, b)

# Test -- seed the global RNG once so every randomized check in this notebook
# is reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)
a = torch.randn(1_000)
b = torch.randn(1_000)
out_spec = torch.zeros(1)
seq_dot_spec(a, b, out_spec)
your_out = seq_dot(a, b)
# your_out is 0-d while out_spec has shape (1,); allclose broadcasts them.
assert torch.allclose(your_out, out_spec, atol=1e-5)

seq_dot_spec took 0.025067 seconds
seq_dot took 0.014832 seconds


In [4]:
# Problem #2 - Batched Self-Attention Score Computation (Easy)
# Description: Compute batched (unscaled) self-attention scores, out[b, i, j] = q[b, i] . k[b, j]

@profile
def batched_attn_scores_spec(q, k, out):
    """Loop reference: out[b, i, j] = sum_d q[b, i, d] * k[b, j, d]."""
    n_batch, n_q, depth = q.shape[0], q.shape[1], q.shape[2]
    n_k = k.shape[1]
    for bi in range(n_batch):
        for qi in range(n_q):
            for ki in range(n_k):
                acc = 0
                for di in range(depth):
                    acc += q[bi][qi][di] * k[bi][ki][di]
                out[bi][qi][ki] = acc
@profile
def batched_attn_scores(q, k):
    """Raw (unscaled) batched attention scores.

    q: (batch, q_len, dim); k: (batch, k_len, dim)
    returns (batch, q_len, k_len) with out[b, i, j] = q[b, i] . k[b, j].
    """
    # d is shared and absent from the output, so it is contracted (dot product).
    return torch.einsum('bqd,bkd->bqk', q, k)

# Test: loop reference vs. einsum on random queries/keys of different lengths.
q = torch.randn(4, 5, 8)  # (batch, q_len, dim)
k = torch.randn(4, 7, 8)  # (batch, k_len, dim)
out_spec = torch.zeros(q.shape[0], q.shape[1], k.shape[1])
batched_attn_scores_spec(q, k, out_spec)
your_out = batched_attn_scores(q, k)
assert torch.allclose(your_out, out_spec, atol=1e-5)

batched_attn_scores_spec took 0.043477 seconds
batched_attn_scores took 0.000252 seconds


In [5]:
# Problem #3: Weighted sum of Values, i.e. Attention output (Easy)
# Description: Compute the weighted sum of value vectors given attention weights

@profile
def weighted_sum_spec(weights, values, out):
    """Loop reference: out[i, d] = sum_j weights[i, j] * values[j, d]."""
    n_rows, n_inner = weights.shape[0], weights.shape[1]
    n_cols = values.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            running = 0
            for j in range(n_inner):
                running += weights[row][j] * values[j][col]
            out[row][col] = running

@profile
def weighted_sum(weights, values):
    """Attention output: each row of `weights` mixes the rows of `values`.

    weights: (n_queries, n_keys); values: (n_keys, dim) -> (n_queries, dim).
    """
    # k (the key axis) is contracted; q and v survive into the output.
    equation = 'qk,kv->qv'
    return torch.einsum(equation, weights, values)

# Test: (5 x 7) weights applied to (7 x 8) values -> (5 x 8) output.
weights = torch.randn(5, 7)
values = torch.randn(7, 8)
out_spec = torch.zeros(weights.shape[0], values.shape[1])
weighted_sum_spec(weights, values, out_spec)
your_out = weighted_sum(weights, values)
assert torch.allclose(your_out, out_spec, atol=1e-5)


weighted_sum_spec took 0.007693 seconds
weighted_sum took 0.000128 seconds


In [6]:
# Problem #4 - Learned Positional Embedding Addition (Easy)
# Description: Add learned positional encodings to token embeddings using position indices

@profile
def add_positional_spec(tokens, positions, pos_embed_weight):
    """Loop reference: out[b, t] = tokens[b, t] + pos_embed_weight[positions[b, t]].

    Works on a clone, so `tokens` itself is left untouched.
    """
    out = tokens.clone()
    n_batch, n_steps = positions.shape
    for bi in range(n_batch):
        for ti in range(n_steps):
            row = pos_embed_weight[positions[bi, ti]]
            out[bi, ti] = tokens[bi, ti] + row
    return out

@profile
def add_positional(tokens, positions, pos_embed_weight):
    """Add learned positional embeddings, selected by index, to token embeddings.

    tokens:           (batch, seq_len, d_embed)
    positions:        (batch, seq_len) int64 position indices
    pos_embed_weight: (max_positions, d_embed), one row per position index

    result[b, t] = tokens[b, t] + pos_embed_weight[positions[b, t]]

    The lookup is expressed as an einsum against a one-hot encoding (this is an
    einsum exercise). Plain indexing, ``tokens + pos_embed_weight[positions]``,
    gives the same result without building the one-hot tensor.
    """
    num_positions = pos_embed_weight.size(0)
    # Match the table's dtype: einsum does not type-promote, so a hard-coded
    # float32 one-hot would fail against float64/float16 embedding tables.
    one_hot = torch.nn.functional.one_hot(positions, num_classes=num_positions).to(pos_embed_weight.dtype)
    return tokens + torch.einsum('btp,pd->btd', one_hot, pos_embed_weight)

# Test: every batch row uses positions 0..9 from a 20-row embedding table.
tokens = torch.randn(2, 10, 64)
positions = torch.arange(10).repeat(2, 1)
pos_embed_weight = torch.randn(20, 64)
fast_out = add_positional(tokens, positions, pos_embed_weight)
reference_out = add_positional_spec(tokens, positions, pos_embed_weight)
assert torch.allclose(fast_out, reference_out, atol=1e-5)

add_positional took 0.000494 seconds
add_positional_spec took 0.000808 seconds


In [7]:
# Problem #5 - Multi-head QKV projection (Medium)
# Description: Transform input embeddings into multi-head query, key, and value projections simultaneously

@profile
def multi_head_qkv_spec(x, qkv_weight):
    """Loop reference for the fused multi-head Q/K/V projection.

    result[b, s, h, t, d] = sum_i x[b, t, i] * qkv_weight[i, s, h, d],
    where s in {0, 1, 2} selects the query/key/value projection.
    Accumulates into a fresh default-dtype zeros tensor.
    """
    n_batch, n_steps, n_in = x.shape
    _, _, n_heads, head_size = qkv_weight.shape
    out = torch.zeros(n_batch, 3, n_heads, n_steps, head_size)
    for bi in range(n_batch):
        for ti in range(n_steps):
            for ii in range(n_in):
                x_val = x[bi, ti, ii]
                for si in range(3):
                    for hi in range(n_heads):
                        for di in range(head_size):
                            out[bi, si, hi, ti, di] += x_val * qkv_weight[ii, si, hi, di]
    return out

@profile
def multi_head_qkv(x, qkv_weight):
    """Project embeddings into per-head queries, keys and values in one einsum.

    x:          (batch, seq_len, embed_dim)
    qkv_weight: (embed_dim, 3, num_heads, head_dim)
    returns:    (batch, 3, num_heads, seq_len, head_dim)
    """
    # e (embed_dim) is contracted; s picks q/k/v; n is the head; k is head_dim.
    pattern = 'bte,esnk->bsntk'
    return torch.einsum(pattern, x, qkv_weight)


# Test
x = torch.randn(2, 4, 4)  # (batch, seq_len, embed_dim)
qkv_weight = torch.randn(4, 3, 8, 16)  # (embed_dim, q/k/v, 8 heads, 16 dims each)
# rtol=1e-3 *loosens* the default relative tolerance. atol guards entries whose
# true value is near zero, where a purely relative bound can fail on ordinary
# float32 rounding differences between the two summation orders.
assert torch.allclose(multi_head_qkv(x, qkv_weight), multi_head_qkv_spec(x, qkv_weight), rtol=1e-3, atol=1e-5)

multi_head_qkv took 0.000145 seconds
multi_head_qkv_spec took 1.495444 seconds


In [8]:
# Problem #6: Multi-Head Attention Output Reconstruction (Medium)
# Description: Concatenate multi-head attention outputs and apply an output projection.

@profile
def reconstruct_multihead_spec(multi_head_output, output_weight):
    """Loop reference: concatenate heads along features, then project.

    Step 1 lays head h's feature d at column h * head_dim + d.
    Step 2 multiplies the concatenated features by `output_weight`.
    """
    n_batch, n_heads, n_steps, head_size = multi_head_output.shape
    out_features = output_weight.shape[1]
    width = n_heads * head_size

    concat = torch.zeros(n_batch, n_steps, width)
    for bi in range(n_batch):
        for ti in range(n_steps):
            for hi in range(n_heads):
                base = hi * head_size
                for di in range(head_size):
                    concat[bi, ti, base + di] = multi_head_output[bi, hi, ti, di]

    projected = torch.zeros(n_batch, n_steps, out_features)
    for bi in range(n_batch):
        for ti in range(n_steps):
            for fi in range(width):
                for oi in range(out_features):
                    projected[bi, ti, oi] += concat[bi, ti, fi] * output_weight[fi, oi]
    return projected

@profile
def reconstruct_multihead(multi_head_output, output_weight):
    """Concatenate attention heads and apply the output projection.

    multi_head_output: (batch, num_heads, seq_len, head_dim)
    output_weight:     (num_heads * head_dim, out_dim). Row h * head_dim + d
                       multiplies feature d of head h (heads concatenated in order).
    returns:           (batch, seq_len, out_dim)

    Rather than materialising the concatenated (batch, seq_len, num_heads * head_dim)
    tensor, the weight is viewed as (num_heads, head_dim, out_dim) and both the
    head and head_dim axes are contracted in one einsum; this is mathematically
    the same as concat-then-matmul.
    """
    num_heads, head_dim = multi_head_output.shape[1], multi_head_output.shape[3]
    # Row-major reshape splits row index h * head_dim + d into (h, d).
    per_head_weight = output_weight.reshape(num_heads, head_dim, output_weight.shape[-1])
    return torch.einsum('bhtd,hdo->bto', multi_head_output, per_head_weight)

# Test
multi_head_output = torch.randn(2, 8, 10, 16)  # 8 heads, 16 dim each
output_weight = torch.randn(128, 64)  # 8*16=128 input, 64 output
# rtol alone (atol defaults to 1e-8) can fail on outputs close to zero, where
# float32 rounding from different summation orders dominates; add an atol.
assert torch.allclose(reconstruct_multihead(multi_head_output, output_weight),
                      reconstruct_multihead_spec(multi_head_output, output_weight),
                      rtol=1e-3, atol=1e-5)

reconstruct_multihead took 2.458075 seconds
reconstruct_multihead_spec took 4.857121 seconds


In [9]:
# Problem #7: Multi-Head Attention Score Computation (Medium)
# Description: Compute attention scores for multi-head self-attention (batched, with multiple heads).

@profile
def mha_scores_spec(q, k, out):
    """Loop reference: out[b, h, i, j] = sum_d q[b, h, i, d] * k[b, h, j, d]."""
    n_b, n_h, n_i, n_d = q.shape[0], q.shape[1], q.shape[2], q.shape[3]
    n_j = k.shape[2]
    for bi in range(n_b):
        for hi in range(n_h):
            for ii in range(n_i):
                for ji in range(n_j):
                    total = 0
                    for di in range(n_d):
                        total += q[bi][hi][ii][di] * k[bi][hi][ji][di]
                    out[bi][hi][ii][ji] = total

@profile
def mha_scores(q, k):
    """Per-head attention scores: out[b, h, i, j] = q[b, h, i] . k[b, h, j].

    q: (batch, heads, q_len, head_dim); k: (batch, heads, k_len, head_dim)
    returns (batch, heads, q_len, k_len).
    """
    spec = 'bhid,bhjd->bhij'
    return torch.einsum(spec, q, k)

# Test: queries and keys share batch/head/head_dim but differ in length (5 vs 7).
q = torch.randn(4, 2, 5, 8)  # batch_size, num_head, q-seq_len, head_dim
k = torch.randn(4, 2, 7, 8)  # batch_size, num_head, k-seq_len, head_dim
out_spec = torch.zeros(*q.shape[:3], k.shape[2])  # batch_size, num_head, q-seq_len, k-seq_len
mha_scores_spec(q, k, out_spec)
your_out = mha_scores(q, k)
assert torch.allclose(your_out, out_spec, atol=1e-5)


mha_scores_spec took 0.075232 seconds
mha_scores took 0.000425 seconds


In [10]:
# Problem #8: Scaled Dot-Product attention with causal mask (Hard)
# Description: Apply the lower triangular mask to compute causal attention

@profile
def causal_attention_spec(query, key, value, scale):
    """Loop reference for causal scaled dot-product attention.

    Query i may attend only to keys j <= i; later keys get a -inf score so that
    softmax assigns them zero weight. Returns a tensor shaped like `value`.
    """
    n_batch, n_steps, _ = query.shape
    scores = torch.zeros(n_batch, n_steps, n_steps)
    for bi in range(n_batch):
        for qi in range(n_steps):
            for ki in range(n_steps):
                if ki > qi:
                    scores[bi, qi, ki] = float('-inf')
                else:
                    scores[bi, qi, ki] = torch.dot(query[bi, qi], key[bi, ki]) * scale
    probs = torch.softmax(scores, dim=-1)
    result = torch.zeros_like(value)
    for bi in range(n_batch):
        for qi in range(n_steps):
            for ki in range(n_steps):
                result[bi, qi] += probs[bi, qi, ki] * value[bi, ki]
    return result

@profile
def causal_attention(query, key, value, scale):
    """Causal scaled dot-product attention.

    query, key, value: (batch, seq_len, dim), all with the same seq_len.
    scale: multiplier applied to the raw dot products (e.g. 1 / sqrt(dim)).
    returns: (batch, seq_len, dim) where position i attends only to keys j <= i.
    """
    seq_len = query.shape[1]
    logits = torch.einsum('bqd,bkd->bqk', query, key) * scale
    # True strictly above the diagonal = "future" keys a query must not see.
    # Built on the input's device so masked_fill also works for non-CPU tensors.
    future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=query.device).triu(diagonal=1)
    logits = logits.masked_fill(future, float('-inf'))
    probs = torch.softmax(logits, dim=-1)
    return torch.einsum('bqk,bkd->bqd', probs, value)


# Test
query = torch.randn(2, 6, 32)
key = torch.randn(2, 6, 32)
value = torch.randn(2, 6, 32)
scale = 1/torch.sqrt(torch.tensor(32.0))
# rtol=1e-3 loosens the default relative tolerance; atol=1e-5 covers outputs
# close to zero, where a purely relative bound can trip on float32 rounding.
assert torch.allclose(causal_attention(query, key, value, scale),
                      causal_attention_spec(query, key, value, scale),
                      rtol=1e-3, atol=1e-5)


causal_attention took 0.000658 seconds
causal_attention_spec took 0.003319 seconds


In [11]:
# Problem #9: Masked Attention Score Computation (Hard)
# Description: Compute batched attention scores with a given causal mask, and set masked-out positions to zero

@profile
def masked_attn_scores_spec(q, k, mask, out):
    """Loop reference: out[b, i, j] = q[b, i] . k[b, j] where mask[i][j] != 0, else 0."""
    for bi in range(q.shape[0]):
        for qi in range(q.shape[1]):
            for ki in range(k.shape[1]):
                if mask[qi][ki] != 0:
                    acc = 0
                    for di in range(q.shape[2]):
                        acc += q[bi][qi][di] * k[bi][ki][di]
                    out[bi][qi][ki] = acc
                else:
                    out[bi][qi][ki] = 0
@profile
def masked_attn_scores(q, k, mask):
    """Batched attention scores with masked-out positions forced to exactly zero.

    q: (batch, q_len, dim); k: (batch, k_len, dim); mask: (q_len, k_len),
    broadcast over the batch. Position (i, j) is kept when mask[i, j] != 0 and
    set to 0 when mask[i, j] == 0, matching masked_attn_scores_spec:
    kept scores are NOT scaled by the mask value, zeroed entries stay 0 even if
    the raw score is inf/nan, and float, integer or bool masks all work.
    """
    scores = torch.einsum('bqd,bkd->bqk', q, k)
    # masked_fill instead of multiplying by the mask: multiplication would scale
    # kept scores by non-unit mask values, turn inf * 0 into nan, and reject bool masks.
    return scores.masked_fill(mask == 0, 0)


# Test: causal (lower-triangular) mask, so entries above the diagonal must be 0.
seq = 5
q = torch.randn(4, seq, 8)
k = torch.randn(4, seq, 8)
mask = torch.ones(seq, seq).tril()
out_spec = torch.zeros(q.shape[0], seq, seq)
masked_attn_scores_spec(q, k, mask, out_spec)
your_out = masked_attn_scores(q, k, mask)
assert torch.allclose(your_out, out_spec, atol=1e-5)


masked_attn_scores_spec took 0.011909 seconds
masked_attn_scores took 0.000323 seconds


In [12]:
# Problem #10: Block Matmul (Hard)
# Description: Compute efficient block matmul, similar to the implementation in FlashAttention

@profile
def block_matmul_spec(a, b, out):
    """
    Loop reference for a grid of independent block products.

    a:   [num_q_blocks, num_k_blocks, block_q, d]
    b:   [num_q_blocks, num_k_blocks, d, block_k]
    out: [num_q_blocks, num_k_blocks, block_q, block_k], filled in place and returned
    """
    n_qb, n_kb, rows, depth = a.shape[0], a.shape[1], a.shape[2], a.shape[3]
    cols = b.shape[3]
    for qb in range(n_qb):
        for kb in range(n_kb):
            for r in range(rows):
                for c in range(cols):
                    acc = 0
                    for t in range(depth):
                        acc += a[qb, kb, r, t] * b[qb, kb, t, c]
                    out[qb, kb, r, c] = acc
    return out

@profile
def block_matmul(a, b):
    """Independent matmul for every (q-block, k-block) pair.

    a: [num_q_blocks, num_k_blocks, block_q, d]
    b: [num_q_blocks, num_k_blocks, d, block_k]
    returns [num_q_blocks, num_k_blocks, block_q, block_k]
    """
    # x, y index the block grid; i/j are rows/cols inside a block; r is contracted.
    equation = 'xyir,xyrj->xyij'
    return torch.einsum(equation, a, b)


# Test: a 2 x 3 grid of (4 x 8) @ (8 x 5) block products.
num_q_blocks, num_k_blocks, block_q, block_k, d = 2, 3, 4, 5, 8
a = torch.randn(num_q_blocks, num_k_blocks, block_q, d)
b = torch.randn(num_q_blocks, num_k_blocks, d, block_k)
out_spec = torch.zeros(*a.shape[:3], block_k)

block_matmul_spec(a, b, out_spec)
your_out = block_matmul(a, b)

assert torch.allclose(your_out, out_spec, atol=1e-5)

block_matmul_spec took 0.020145 seconds
block_matmul took 0.000239 seconds
