In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

EX1: The n-dimensional tensor mastery challenge: Combine the `Head` and `MultiHeadAttention` into one class that processes all the heads in parallel, treating the heads as another batch dimension (answer is in nanoGPT).

In [11]:
""" Global Variables """
block_size = 32
batch_size = 16
embed_dim = 32
num_heads = 4
num_blocks = 4

In [3]:
""" Create Dataset """
text = open("video_7_dependencies/input.txt").read().splitlines()

In [4]:
vocab = sorted(list(set("".join(text))))
print(len(vocab))
stoi = {token: idx for idx, token in enumerate(vocab)}
itos = {idx: token for token, idx in stoi.items()}
dataset = [stoi[t] for t in "".join(text)]

64


In [5]:
def create_batches(batch_size, mode = "train"):
    if mode == "train":
        idxs = torch.randint(0, int(len(dataset)*0.9) - block_size, size = (batch_size, ))
    else:
        idxs = torch.randint(int(len(dataset)*0.9), len(dataset) - block_size, size = (batch_size, ))
    X = []
    Y = []
    for idx in idxs:
        X.append(dataset[idx:idx+block_size])
        Y.append(dataset[idx + 1: idx + block_size + 1])
    X = torch.tensor(X).to(torch.long)
    Y = torch.tensor(Y)
    return X, Y

Below, I will create a "Transformer-like" architecture which uses two implementations of multi-head attention (iterations vs vectorized approach) and determine whether the outputs are the same. To do so, I fix the weights to the same set of values for the two implementations

In [6]:
""" Create Model """

class MaskedSelfAttention(nn.Module):
    def __init__(self, embed_dim, proj_dim, block_size):
        super().__init__()
        self.w_q = nn.Linear(embed_dim, proj_dim, bias = False)
        self.w_k = nn.Linear(embed_dim, proj_dim, bias = False)
        self.w_v = nn.Linear(embed_dim, proj_dim, bias = False)

        with torch.no_grad():
            self.w_q.weight = nn.Parameter(torch.arange(0, embed_dim*proj_dim).view(proj_dim, embed_dim).to(torch.float32))
            self.w_k.weight = nn.Parameter(torch.arange(0, embed_dim*proj_dim).view(proj_dim, embed_dim).to(torch.float32))
            self.w_v.weight = nn.Parameter(torch.arange(0, embed_dim*proj_dim).view(proj_dim, embed_dim).to(torch.float32))
            
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
        
    def forward(self, input):
        B, T, C = input.shape
        query = self.w_q(input) #B,T,proj_dim
        key = self.w_k(input) #B,T,proj_dim
        wei = (query @ key.permute(0,2,1))*(C**-0.5) #B,T,T
        wei = wei.masked_fill(self.tril[:T,:T] == 0, float("-inf")) #Note: subset the mask self.tril[:T,:T] in case input sequence is less then block_size; the mask needs to be broadcastable with wei
        wei = F.softmax(wei, dim = -1)

        value = self.w_v(input) #B,T,proj_dim
        out = wei @ value
        
        return out
    
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, embed_dim, block_size):
        super().__init__()
        self.multiple_attention = nn.ModuleList([MaskedSelfAttention(embed_dim, int(embed_dim/num_heads), block_size) for i in range(num_heads)])
    def forward(self, input):
        out = torch.cat([attn(input) for attn in self.multiple_attention], dim = -1) #B,T, embed_dim
        return out
    

class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, num_heads, embed_dim, block_size):
        super().__init__()
        self.num_heads = num_heads
        self.proj_dim = int(embed_dim/num_heads)
        self.w_q = nn.Linear(embed_dim, num_heads*self.proj_dim, bias = False)
        self.w_k = nn.Linear(embed_dim, num_heads*self.proj_dim, bias = False)
        self.w_v = nn.Linear(embed_dim, num_heads*self.proj_dim, bias = False)

        with torch.no_grad():
            self.w_q.weight = nn.Parameter(torch.arange(0, embed_dim*self.proj_dim).repeat(num_heads).view(num_heads*self.proj_dim, embed_dim).to(torch.float32))
            self.w_k.weight = nn.Parameter(torch.arange(0, embed_dim*self.proj_dim).repeat(num_heads).view(num_heads*self.proj_dim, embed_dim).to(torch.float32))
            self.w_v.weight = nn.Parameter(torch.arange(0, embed_dim*self.proj_dim).repeat(num_heads).view(num_heads*self.proj_dim, embed_dim).to(torch.float32))
            
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
        
    def forward(self, input):
        B, T, C = input.shape #C == embed_dim
        query = self.w_q(input).view(B, T, self.num_heads, self.proj_dim) #B, T, H, proj_dim
        key = self.w_k(input).view(B, T, self.num_heads, self.proj_dim) #B, T, H, proj_dim
        wei = (query.permute(0,2,1,3) @ key.permute(0,2,3,1))*(C**-0.5) #B, H, T, T
        wei = wei.masked_fill(self.tril[:T,:T] == 0, float("-inf")) 
        wei = F.softmax(wei, dim = -1)


        value = self.w_v(input).view(B, T, self.num_heads, self.proj_dim) #B, T, H, proj_dim
        out = wei @ value.permute(0,2,1,3) #B, H, T, proj_dim
        out = out.permute(0,2,1,3).contiguous().view(B,T,C) #B, T, C

        return out

In [7]:
class Transformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, block_size, num_heads):
        super().__init__()
        self.content_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(block_size, embed_dim)
        self.iterable_attention = MultiHeadAttention(num_heads, embed_dim, block_size)
        self.tensor_attention = MaskedMultiHeadAttention(num_heads, embed_dim, block_size)

    def forward(self, input):
        B, T = input.shape
        con_embed = self.content_embedding(input) #B,T,embed_dim
        pos_embed = self.position_embedding(torch.arange(T)) #1,T,embed_dim
        x = con_embed + pos_embed #B,T,embed_dim
        x1 = self.iterable_attention(x)
        x2 = self.tensor_attention(x)
        return [x1, x2]

In [8]:
model = Transformer(len(vocab), embed_dim, block_size, num_heads)

In [12]:
x = create_batches(batch_size)[0]

In [14]:
iterable_approach, tensor_approach = model(x)
print(iterable_approach.shape)
print(iterable_approach.sum())

print(tensor_approach.shape)
print(tensor_approach.sum())

print(iterable_approach == tensor_approach)

torch.Size([16, 32, 32])
tensor(-16841928., grad_fn=<SumBackward0>)
torch.Size([16, 32, 32])
tensor(-16841928., grad_fn=<SumBackward0>)
tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, T

NEED TO BE CAREFUL WITH VIEW: was creating a bug where going from B,T,num_heads x n_proj -> B,num_heads,T,n_proj was not generating an equivalent result to the iterable variant of MHSA. Instead first veiwing it to
B,T,num_heads,n_proj and then permuting it resolved the issues

Nice thread which higlights problems with using view for dimension swapping: https://discuss.pytorch.org/t/for-beginners-do-not-use-view-or-reshape-to-swap-dimensions-of-tensors/75524


Same issues when going from B,num_heads,T,n_proj -> B,T,num_heads x n_proj when generating the output of MHSA, need to first permute to  B,T,num_heads,n_proj then convert to contigous tensor (needed to "view" a permuted tensor from more to less dims; note this creates a copy of the non-continguous tensor as a contigious tensor), then view to B,T,num_heads x n_proj