In [32]:
import torch
from torch.nn import functional as F
from torch import nn

torch_version = torch.__version__  # record which PyTorch build produced the outputs in this notebook
print(torch_version)

2.0.0


In [120]:
# Seed torch's global RNG so torch.rand / torch.randint / weight init below are
# reproducible on Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

b = 32  # batch size
t = 24  # max sequence length (number of tokens)
c = 64  # embedding dimensions

In [121]:
# Random stand-in for a batch of already-embedded tokens: (batch, seq_len, emb_dim)
x = torch.rand(b, t, c)

In [126]:
class Attention(nn.Module):
    """A single self-attention head using scaled dot-product attention, optionally causal.

    The input is projected to query/key/value of size ``head_size``. The head then
    computes ``softmax(Q K^T / sqrt(head_size)) @ V``.

    Args:
        emb_dim: size of the last dimension of the input.
        head_size: size of the query/key/value projections and of the output.
        masked: if truthy, apply a causal mask so position i only attends to
            positions <= i.
    """

    def __init__(self, emb_dim, head_size, masked=False):
        super().__init__()
        self.emb_dim = emb_dim
        self.head_size = head_size
        self.masked = masked

        # TODO: Check if these projections should have bias or not
        self.toquery = nn.Linear(emb_dim, head_size)
        self.tokey = nn.Linear(emb_dim, head_size)
        self.tovalue = nn.Linear(emb_dim, head_size)

    def forward(self, x):
        """Apply the attention head.

        Args:
            x: tensor of shape (b, t, emb_dim).

        Returns:
            Tensor of shape (b, t, head_size).
        """
        _, t, _ = x.size()  # unpacking also enforces a 3-D input
        # Project input into query, key, and value
        Q = self.toquery(x)  # b, t, head_size
        K = self.tokey(x)    # b, t, head_size
        V = self.tovalue(x)  # b, t, head_size

        # Swap the last two dims of K so the product is (b, t, head_size) @ (b, head_size, t) = b, t, t
        att = Q @ K.transpose(-2, -1)
        # Scale by sqrt(d_k). d_k is the length of the dot products (head_size), not
        # the full embedding size. This keeps the softmax inputs at roughly unit variance.
        att_scaled = att / (self.head_size ** 0.5)

        # Apply masking to allow tokens to only attend to the left, not to the right
        if self.masked:
            # Build the mask on the input's device so this also works on GPU.
            mask = torch.tril(torch.ones(t, t, dtype=torch.bool, device=x.device))
            att_scaled = att_scaled.masked_fill(~mask, float('-inf'))

        # Softmax scores to get weights
        weights = F.softmax(att_scaled, dim=-1)  # b, t, t

        # Multiply softmaxed weights with values
        out = weights @ V  # (b, t, t) @ (b, t, head_size) = b, t, head_size

        return out

In [127]:
# One causal head projecting emb_dim -> emb_dim // 4 per token
a = Attention(c, c // 4, masked=True)
p = a(x)
assert p.shape == (b, t, c // 4)
p.shape

torch.Size([32, 24, 16])

In [128]:
class MultiHeadAttention(nn.Module):
    """Run several Attention heads in parallel, concatenate them and mix them linearly.

    Args:
        emb_dim: embedding size of the input. It must be divisible by ``n_heads``.
        n_heads: number of parallel heads. Each head has size ``emb_dim // n_heads``.
        masked: forwarded to every head (truthy -> causal masking).
    """

    def __init__(self, emb_dim, n_heads, masked=False):
        super().__init__()

        self.emb_dim = emb_dim
        self.n_heads = n_heads

        # The embedding dimension must be divisible by the number of heads so each
        # head gets an equal slice.
        assert emb_dim % n_heads == 0
        per_head = emb_dim // n_heads

        heads = [Attention(emb_dim, per_head, masked) for _ in range(n_heads)]
        self.heads = nn.ModuleList(heads)

        # Projects the concatenated head outputs back into a single emb_dim representation.
        self.unifyheads = nn.Linear(emb_dim, emb_dim, bias=False)

    def forward(self, x):
        """Apply every head to ``x`` and combine the results.

        Each Attention head unpacks a 3-D (b, t, emb_dim) input. The n_heads outputs of
        size head_size are joined along the last axis, which restores emb_dim.
        """
        concatenated = torch.cat([head(x) for head in self.heads], dim=-1)  # b, t, emb_dim
        return self.unifyheads(concatenated)  # b, t, emb_dim

In [129]:
# Four unmasked heads of size c // 4, concatenated back to c
mha = MultiHeadAttention(c, 4, masked=False)
out = mha(x)
assert out.shape == x.shape
print(out.shape)

torch.Size([32, 24, 64])


In [130]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then a position-wise MLP.

    Each sub-layer sees a LayerNorm-ed input, and its output is added back to the
    residual stream.

    Args:
        emb_dim: model width (last dimension of input and output).
        ff_dim: hidden size of the feed-forward MLP.
        n_heads: number of attention heads.
        masked: forwarded to the attention layer (truthy -> causal masking).
    """

    def __init__(self, emb_dim, ff_dim, n_heads, masked=False):
        super().__init__()
        self.mha = MultiHeadAttention(emb_dim, n_heads, masked)
        expand = nn.Linear(emb_dim, ff_dim)
        contract = nn.Linear(ff_dim, emb_dim)
        self.ff = nn.Sequential(expand, nn.ReLU(), contract)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        # Normalize *before* each sub-layer; the residual path carries the un-normalized stream.
        attended = x + self.mha(self.ln1(x))
        return attended + self.ff(self.ln2(attended))


In [131]:
tb = TransformerBlock(emb_dim=c, ff_dim=c * 4, n_heads=8)
out = tb(x)
assert out.shape == x.shape  # residual block preserves (b, t, emb_dim)
out.shape

torch.Size([32, 24, 64])

In [132]:
class TransformerEncoder(nn.Module):
    """Transformer encoder that classifies token-id sequences.

    The pipeline is: token embeddings + learned positional embeddings -> a stack of
    ``n_layers`` TransformerBlocks -> mean over the token axis -> linear classifier.

    Args:
        emb_dim: embedding / model width.
        ff_dim: hidden size of each block's feed-forward MLP.
        vocab_size: number of distinct token ids.
        max_seq_len: number of learned positions. Input sequences may not be longer.
        n_layers: number of stacked TransformerBlocks.
        n_heads: attention heads per block.
        n_classes: size of the output logits.
        masked: forwarded to every block (truthy -> causal attention).
    """

    def __init__(self, emb_dim, ff_dim, vocab_size, max_seq_len, n_layers, n_heads, n_classes, masked=None):
        super().__init__()

        self.token_embeddings = nn.Embedding(vocab_size, emb_dim)
        self.positional_embeddings = nn.Embedding(max_seq_len, emb_dim)

        self.transformer_blocks = nn.Sequential(*[
            TransformerBlock(emb_dim, ff_dim, n_heads, masked) for _ in range(n_layers)
        ])

        self.cls_layer = nn.Linear(emb_dim, n_classes)

    def forward(self, x):
        """Return class logits for a batch of token-id sequences.

        Args:
            x: integer tensor of token ids, shape (b, t) with t <= max_seq_len.

        Returns:
            Logits of shape (b, n_classes).

        Raises:
            IndexError: if t exceeds the number of learned positions.
        """
        # Retrieve token embeddings
        embeddings = self.token_embeddings(x)  # b, t, emb_dim
        _, t, _ = embeddings.size()

        # Fail with a clear message instead of an opaque embedding lookup error
        # (or a device-side assert on CUDA).
        max_len = self.positional_embeddings.num_embeddings
        if t > max_len:
            raise IndexError(f"sequence length {t} exceeds max_seq_len {max_len}")

        # Positions 0..t-1, created on the input's device so this also runs on GPU.
        positions = torch.arange(t, device=x.device)
        pos_embeddings = self.positional_embeddings(positions)  # t, emb_dim; broadcasts over the batch

        # Add token embeddings and positional encodings to obtain the transformer input
        hidden = embeddings + pos_embeddings

        # Run data through transformer blocks
        hidden = self.transformer_blocks(hidden)

        # Take average across tokens
        pooled = hidden.mean(dim=1)  # b, t, emb_dim -> b, emb_dim

        # Final linear layer to obtain one value per class
        logits = self.cls_layer(pooled)

        return logits



In [133]:
# Hyper-parameters for the classification demo (note: this rebinds b and t from the earlier config cell)
b, t = 32, 24
vocab_size = 1000
emb_dim = 128
ff_dim = 4 * emb_dim  # MLP hidden size: 4x the model width
n_layers, n_heads = 8, 8
n_classes = 10

# A random batch of token ids in [0, vocab_size)
tokens = torch.randint(low=0, high=vocab_size, size=(b, t))

In [134]:
# Keyword arguments guard against silently swapping any of the seven int hyper-parameters.
te = TransformerEncoder(
    emb_dim=emb_dim,
    ff_dim=ff_dim,
    vocab_size=vocab_size,
    max_seq_len=t,  # the positional table only needs to cover the demo sequence length
    n_layers=n_layers,
    n_heads=n_heads,
    n_classes=n_classes,
)
out = te(tokens)
assert out.shape == (b, n_classes)
out.shape

torch.Size([32, 10])

In [119]:
# Demo: a lower-triangular mask turns uniform scores into causal attention weights.
T = 8
tril = torch.tril(torch.ones(T, T))
print(tril)

# Disallowed (upper-triangle) positions get -inf, so softmax gives them zero weight.
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
print(wei)

# Each row averages uniformly over the positions it is allowed to see.
sm = F.softmax(wei, dim=-1)
print(sm)

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000,