<a href="https://colab.research.google.com/github/alexchen1999/deeplearning-from-scratch/blob/main/transformer_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
import torch
import torch.nn as nn

In [47]:
class MultiHeadAttention(nn.Module):
  '''
  Split an embedding into different parts

  E.g. if we had an embedding size of 256 and 8 attention heads then the
  input embedding would be split into 8 parts of size 32

  '''
  def __init__(self, embedding_size, n_heads):
    super(MultiHeadAttention, self).__init__()
    self.embedding_size = embedding_size
    self.n_heads = n_heads
    self.head_dimension = embedding_size // n_heads # integer division

    # Q, K, V
    self.values = nn.Linear(self.head_dimension, self.head_dimension, bias=False) # attention matrix --> maps head dimension to head dimension for split embedding
    self.keys = nn.Linear(self.head_dimension, self.head_dimension, bias=False)
    self.queries = nn.Linear(self.head_dimension, self.head_dimension, bias=False)

    self.fc_out = nn.Linear(self.n_heads * self.head_dimension, self.embedding_size) # mutiplication between n_heads and head_dimension is some convention to make it clear we're concatenating the heads back to the original embedding size

  def forward(self, keys, values, query, mask):
    batch_size = query.shape[0] # How many examples we send at once
    value_length, key_length, query_length = values.shape[1], keys.shape[1], query.shape[1]

    # Objective is to reshape the Q,K,V tensors so that attention can be applied to multiple heads
    # Original input size is (batch_size, length, embedding_size) for each I believe (see: https://pytorch.org/docs/stable/generated/torch.reshape.html)
    # split embedding into n_heads --> does this mean that the embedding size must be cleanly divisible by the number of heads
    values = values.reshape(batch_size, value_length, self.n_heads, self.head_dimension)
    keys = keys.reshape(batch_size, key_length, self.n_heads, self.head_dimension)
    query = query.reshape(batch_size, query_length, self.n_heads, self.head_dimension)

    # Objective: Multiply queries with the keys in attention formula, but need to prepare tensor dimensions for batch matrix multiplication.
    # One way to do it easily is with einsum. The einsum operation "nqhd,nkhd->nhqk" essentially computes the dot product between every query and key pair across all batches and heads.
    # QK = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
    # (batch_size, query_length, n_heads, head_dimension)
    # (batch_size, key_length, n_heads, head_dimension)
    # Output of this matmul is (batch_size, n_heads, query_length, key_length)

    # Equivalently one could manually emulate einsum operation:
    # Step 1: Transpose the heads and sequence length dimensions to align them for batch matrix multiplication
    query = query.transpose(1, 2)  # New shape: [batch_size, n_heads, seq_len, head_dim]
    keys = keys.transpose(1, 2)  # Also, [batch_size, n_heads, seq_len, head_dim]

    # Before performing bmm, keys need to be transposed so that the last two dimensions are [head_dim, seq_len]
    # We're interested in computing the attention between the sequence lengths of queries and keys which is why we need to transpose with head_dim to satisfy inner product dimensions
    keys = keys.transpose(2, 3)  # New shape: [batch_size, n_heads, head_dim, seq_len]

    # Next have to combine batch size and heads into one dimension because torch.bmm expects a 3D input
    query_flattened = query.reshape(batch_size * self.n_heads, query_length, self.head_dimension)
    keys_flattened = keys.reshape(batch_size * self.n_heads, self.head_dimension, key_length)

    # Step 2: Perform batch matrix multiplication
    # Output shape after bmm: [batch_size, n_heads, seq_len (from queries), seq_len (from keys)]
    # Energy/QK intepretation: for each word in our target (query) sentence, how much should we pay attention to each word in our source sentence
    QK = torch.bmm(query_flattened, keys_flattened)

    # Unflatten to get back batch size and n_heads
    QK = QK.reshape(batch_size, self.n_heads, query_length, key_length)

    # Mask certain elements with "-inf" (in practice a very small number)
    # https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill_.html#torch.Tensor.masked_fill_
    if mask is not None:
      QK = QK.masked_fill_(mask == 0, float("-1e20"))

    # Compute attention across the key length
    attention = torch.softmax(QK / (self.embedding_size ** (1/2)), dim=3)

    # attention = [batch_size, n_heads, query_length, key_length]
    # values = [batch_size, value_length, n_heads, head_dimension]
    # output = [batch_size, query_length, n_heads, head_dimension]

    attention_flattened = attention.reshape(batch_size * self.n_heads, query_length, key_length)
    values_flattened = values.reshape(batch_size * self.n_heads, value_length, self.head_dimension)

    # batch matmul
    output_flattened = torch.bmm(attention_flattened, values_flattened)
    output = output_flattened.reshape(batch_size, query_length, self.n_heads * self.head_dimension)

    output = self.fc_out(output)
    return output

In [11]:
# Testing a la torch.bmm example: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm

batch_size = 3
n_heads = 2
head_dimension = 10
query_length = 256 # Some example 256 word query sentence
key_length = 32
value_length = 32 # Key length and value length should be the same since they're pairs

'''
torch.bmm expects both tensors to be 3D with shapes (batch_size, *, *), where the first dimension is the batch size,
and the next two dimensions are the matrix dimensions that will be multiplied together. In a multi-head attention scenario,
you have an additional "heads" dimension, so you'll need to reshape your tensors to temporarily combine the batch and heads
dimensions into one, perform the batch matrix multiplication, then reshape the result back to separate the batch and heads dimensions.

'''

attention = torch.randn(batch_size, n_heads, query_length, key_length)
values = torch.randn(batch_size, value_length, n_heads, head_dimension)

attention_flattened = attention.reshape(batch_size * n_heads, query_length, key_length)
values_flattened = values.reshape(batch_size * n_heads, value_length, head_dimension)

print(attention.shape)
print(values.shape)

print(attention_flattened.shape)
print(values_flattened.shape)

output_flattened = torch.bmm(attention_flattened, values_flattened)
print(output_flattened.shape)

output = output_flattened.reshape(batch_size, query_length, n_heads, head_dimension)
print(output.shape)

torch.Size([3, 2, 256, 32])
torch.Size([3, 32, 2, 10])
torch.Size([6, 256, 32])
torch.Size([6, 32, 10])
torch.Size([6, 256, 10])
torch.Size([3, 256, 2, 10])


In [12]:
print(attention_flattened)

tensor([[[ 0.8661, -0.3334,  2.1608,  ...,  0.0204, -0.8349,  0.4355],
         [-0.1408, -0.2549, -0.9407,  ...,  1.2538,  0.5239, -0.5087],
         [ 0.9114, -0.0209, -1.1141,  ...,  2.1751, -0.2217,  0.3731],
         ...,
         [ 0.7588, -0.1412, -1.0503,  ...,  1.2238, -0.7840, -0.3918],
         [-0.0598, -0.6259, -1.5142,  ...,  0.7198, -1.4527,  0.0581],
         [ 0.6848,  1.0335, -0.5327,  ..., -1.0223, -0.6783,  0.3035]],

        [[ 1.5701,  0.0271, -0.0106,  ..., -0.1874,  0.2687, -0.7438],
         [-0.0702, -0.6042,  0.0681,  ..., -0.3650, -0.1657,  0.3115],
         [-0.2184,  1.3624,  0.3536,  ..., -0.3137,  0.2350, -1.0870],
         ...,
         [ 0.5249, -0.6473,  0.5684,  ..., -0.1210,  0.7735,  0.1469],
         [ 0.9371, -1.0981, -0.8474,  ...,  0.7108, -2.0175, -1.4260],
         [ 0.4808, -1.7110, -0.2402,  ...,  0.0283,  0.5715,  0.0600]],

        [[-0.2092, -0.9602, -0.0471,  ..., -0.8422, -0.3692,  0.1834],
         [-1.1549, -1.3857,  1.4099,  ...,  2

In [13]:
print(output)

tensor([[[[ -0.1981,  19.6662,   2.3831,  ...,  -0.6509, -18.2371,  -9.5532],
          [ -8.7651,  -6.2180,  -2.6235,  ..., -10.6252,   6.7315,   8.9361]],

         [[ -0.6194,  -2.7457,   1.3941,  ...,  -1.6666,   7.6069,   5.1755],
          [ 12.5339,  -4.0733,   0.7828,  ...,  -7.7826,   4.9962, -11.8000]],

         [[  4.8412,  -2.6500,   3.8612,  ...,   5.8582,  -0.1182,   0.5946],
          [  1.2622,   5.7494,  -0.3883,  ...,   7.0686, -10.2198,  -2.7761]],

         ...,

         [[ -0.2550,  -7.1119,  -3.6585,  ...,  -3.2007,   0.4311,  -5.5778],
          [ -2.7743,  -2.8833,  -7.1750,  ...,  -5.6871,  -5.8650,  -0.4186]],

         [[ -6.9740,  -4.3421,  -1.2973,  ...,   1.4527,  -2.8657,  -8.5665],
          [ -2.1417,   4.9784,   2.3405,  ...,  -4.5241,  -7.0783,   9.0825]],

         [[ -1.6823,   2.0054,   7.2486,  ...,   3.0606,   5.0798,   1.0843],
          [  0.9063,   5.0518,   0.1211,  ..., -10.1762, -10.4950,  -0.0917]]],


        [[[ -5.7294,   2.8302,  -9.

In [48]:
class TransformerBlock(nn.Module):
  def __init__(self, embedding_size, n_heads, dropout, forward_expansion):
    super(TransformerBlock, self).__init__()

    print(embedding_size, n_heads)

    self.mha = MultiHeadAttention(embedding_size, n_heads)
    self.norm1 = nn.LayerNorm(embedding_size) # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html, reference: https://arxiv.org/abs/1607.06450
                                              # SImilar to torch.nn.BatchNorm2D -- only difference is BatchNorm normalizes within a batch of samples whereas LayerNorm is just within a sample.
    self.feed_forward = nn.Sequential(
        nn.Linear(embedding_size, forward_expansion * embedding_size), # map the embedding size to the forward expansion times the embedding size; in original paper value of forward expansion is 4x.
        nn.ReLU(),
        nn.Linear(forward_expansion * embedding_size, embedding_size)  # map it back to original embedding size.
    )

    self.norm2 = nn.LayerNorm(embedding_size)
    self.dropout = nn.Dropout(dropout)


  def forward(self, keys, values, query, mask):
    mha = self.mha(values, keys, query, mask) # How does this line make sense
    norm1 = self.dropout(self.norm1(mha + query)) # Add a residual connection and dropout
    feed_forward = self.feed_forward(norm1)
    output = self.dropout(self.norm2(feed_forward + norm1)) # Add another residual connection
    return output


In [50]:
class Encoder(nn.Module):
  '''
  max_sentence_length is related to the positional embedding. Send in how long is the max sentence length


  '''
  def __init__(self, src_vocab_size, max_sentence_length, embedding_size, n_layers, n_heads, device, forward_expansion, dropout):
    super(Encoder, self).__init__()

    self.word_embedding = nn.Embedding(src_vocab_size, embedding_size) # [src_vocab_size, embedding_size]
    self.positional_embedding = nn.Embedding(max_sentence_length, embedding_size) # [max_sentence_length, embedding_size]

    print('encoder')
    print(embedding_size, n_heads, dropout, forward_expansion)

    self.layers = nn.ModuleList([
        TransformerBlock(embedding_size, n_heads, dropout=dropout, forward_expansion=forward_expansion)
    ])

    self.dropout = nn.Dropout(dropout)
    self.device = device

  def forward(self, x, mask):
    N, seq_length = x.shape
    positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device) # Positional embeddings: N x [0, 1, 2, ..., seq_length]

    output = self.dropout(self.word_embedding(x) + self.positional_embedding(positions))

    for layer in self.layers:
      output = layer(output, output, output, mask)

    return output


class DecoderBlock(nn.Module):
  def __init__(self, embedding_size, n_heads, forward_expansion, dropout, device):
    super(DecoderBlock, self).__init__()

    self.mha = MultiHeadAttention(embedding_size, n_heads)
    self.norm = nn.LayerNorm(embedding_size)
    self.transformer_block = TransformerBlock(embedding_size, n_heads, dropout, forward_expansion)
    self.dropout = nn.Dropout(dropout)


  def forward(self, x, value, key, src_mask, tgt_mask):
    attention = self.mha(x, x, x, tgt_mask)
    query = self.dropout(self.norm(x + attention))
    output = self.transformer_block(value, key, query, src_mask)
    return output


class Decoder(nn.Module):
  def __init__(self, tgt_vocab_size, embedding_size, n_layers, n_heads, forward_expansion, dropout, device, max_sentence_length):
      super(Decoder, self).__init__()

      self.device = device
      self.word_embedding = nn.Embedding(tgt_vocab_size, embedding_size)
      self.positional_embedding = nn.Embedding(max_sentence_length, embedding_size)

      self.layers = nn.ModuleList([DecoderBlock(embedding_size, n_heads, forward_expansion, dropout, device) for _ in range(n_layers)]) # N decoder blocks
      self.fc_out = nn.Linear(embedding_size, tgt_vocab_size)
      self.dropout = nn.Dropout(dropout)

  def forward(self, x, encoder_out, src_mask, tgt_mask):
    N, seq_length = x.shape
    positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
    x = self.dropout((self.word_embedding(x) + self.positional_embedding(positions)))

    for layer in self.layers:
      x = layer(x, encoder_out, encoder_out, src_mask, tgt_mask)

    output = self.fc_out(x)
    return output


In [52]:
class Transformer(nn.Module):
  def __init__(self, src_vocab_size, tgt_vocab_size, src_pad_idx, tgt_pad_idx, embedding_size=256, n_layers=6, forward_expansion=4, n_heads=8, dropout=0, device='cuda', max_sentence_length=100):

    super(Transformer, self).__init__()

    self.encoder = Encoder(src_vocab_size=src_vocab_size, embedding_size=embedding_size, n_layers=n_layers, n_heads=n_heads, device=device, forward_expansion=forward_expansion, dropout=dropout, max_sentence_length=max_sentence_length)
    self.decoder = Decoder(tgt_vocab_size=tgt_vocab_size, embedding_size=embedding_size, n_layers=n_layers, n_heads=n_heads, forward_expansion=forward_expansion, dropout=dropout, device=device, max_sentence_length=max_sentence_length)

    self.src_pad_idx = src_pad_idx
    self.tgt_pad_idx = tgt_pad_idx
    self.device = device

  def make_src_mask(self, src):
    src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
    return src_mask.to(self.device)

  def make_tgt_mask(self, tgt):
    N, tgt_length = tgt.shape
    tgt_mask = torch.tril(torch.ones((tgt_length, tgt_length))).expand(N, 1, tgt_length, tgt_length)
    return tgt_mask.to(self.device)

  def forward(self, src, tgt):
    src_mask = self.make_src_mask(src)
    tgt_mask = self.make_tgt_mask(tgt)

    src_encoder = self.encoder(src, src_mask)
    output = self.decoder(tgt, src_encoder, src_mask, tgt_mask)

    return output

In [53]:
x = torch.tensor([[1], [2], [3]])
print(x.size())
x.expand(3, 4)

torch.Size([3, 1])


tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])

In [54]:
x = torch.tensor([[1], [2], [3]])
print(x)
x.expand(3, 5)

tensor([[1],
        [2],
        [3]])


tensor([[1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]])

In [55]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(
    device
)
trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

src_pad_idx = 0
trg_pad_idx = 0
src_vocab_size = 10
trg_vocab_size = 10
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(
    device
)
out = model(x, trg[:, :-1])
print(out.shape)

cpu
encoder
256 8 0 4
256 8
256 8
256 8
256 8
256 8
256 8
256 8
torch.Size([2, 7, 10])
