In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class ScaledDotProductAttention(nn.Module):
  """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

  Parameter-free; operates on the last two axes, so any leading batch/head
  axes are carried through unchanged.
  """
  def __init__(self):
    super().__init__()

  def forward(self, q, k, v, mask=None):
    """Attend from each query to all (unmasked) keys.

    Args:
      q: (..., q_len, d_k) queries; in this notebook (batch, n_heads, q_len, head_dim).
      k: (..., k_len, d_k) keys.
      v: (..., k_len, d_v) values.
      mask: optional tensor broadcastable to (..., q_len, k_len). Entries equal
        to 0/False are excluded from the softmax.

    Returns:
      (..., q_len, d_v) attention-weighted values.
    """
    d_k = k.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (..., q_len, k_len)
    if mask is not None:
      # Use the most negative finite value instead of -inf. If every key of a
      # query row is masked (e.g. an all-padding sequence), -inf makes softmax
      # return NaN, which then spreads through every later layer. With a finite
      # fill such a row gets uniform weights instead, while rows that keep at
      # least one key are unaffected (the masked exps underflow to exactly 0).
      scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    # Normalize over the last axis (keys) so each query's weights sum to 1.
    attn = F.softmax(scores, dim=-1)
    return attn @ v  # (..., q_len, d_v)

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head attention.

  Projects queries, keys and values with learned linear maps, splits the model
  dimension into n_heads heads of size d_model // n_heads, runs scaled
  dot-product attention in every head in parallel, then concatenates the heads
  and applies an output projection.

  Args:
    d_model: embedding dimension of inputs and outputs; must be divisible by n_heads.
    n_heads: number of attention heads.

  Raises:
    ValueError: if d_model is not divisible by n_heads.
  """
  def __init__(self, d_model, n_heads):
    super().__init__()
    # Fail fast: otherwise the head split in forward fails with an opaque reshape error.
    if d_model % n_heads != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
    self.n_heads = n_heads
    self.d_head = d_model // n_heads
    self.W_q = nn.Linear(d_model, d_model)  # nn.Linear computes x @ W.T + b
    self.W_k = nn.Linear(d_model, d_model)
    self.W_v = nn.Linear(d_model, d_model)
    self.W_o = nn.Linear(d_model, d_model)
    # Parameter-free, so creating it once here adds nothing to state_dict.
    self.attention = ScaledDotProductAttention()

  def _split_heads(self, x):
    """Reshape (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_head)."""
    batch_size, seq_len, _ = x.shape
    # reshape alone splits d_model into (n_heads, d_head); the transpose then
    # moves the head axis in front of seq_len so each head attends independently.
    return x.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)

  def forward(self, q, k, v, mask=None):
    """Attend from q over (k, v).

    Args:
      q: (batch, q_len, d_model) query inputs.
      k: (batch, k_len, d_model) key inputs.
      v: (batch, k_len, d_model) value inputs.
      mask: optional, broadcastable to (batch, n_heads, q_len, k_len); entries
        equal to 0/False are blocked.

    Returns:
      (batch, q_len, d_model) output of the final projection W_o.
    """
    batch_size, q_len, _ = q.shape
    q = self._split_heads(self.W_q(q))
    k = self._split_heads(self.W_k(k))
    v = self._split_heads(self.W_v(v))

    heads = self.attention(q, k, v, mask)  # (batch, n_heads, q_len, d_head)
    # Undo the head split: back to (batch, q_len, n_heads * d_head) == (batch, q_len, d_model).
    concat = heads.transpose(1, 2).reshape(batch_size, q_len, self.n_heads * self.d_head)
    return self.W_o(concat)

In [None]:
class InputEmbedding(nn.Module):
  """Token-id lookup table whose vectors are multiplied by sqrt(d_model).

  Args:
    vocab_size: number of distinct token ids.
    d_model: embedding dimension.
  """
  def __init__(self, vocab_size, d_model):
    super().__init__()
    self.d_model = d_model
    self.embedding = nn.Embedding(vocab_size, d_model)

  def forward(self, x):
    """Map integer ids of any shape to (*x.shape, d_model) scaled embeddings."""
    scale = math.sqrt(self.d_model)
    vectors = self.embedding(x)
    return vectors * scale

In [None]:
class LayerNorm(nn.Module):
  """Layer normalization over the last dimension.

  y = gamma * (x - mean) / sqrt(var + eps) + beta

  mean and var are computed over the last axis of each position. var is the
  biased (population) variance, as in the LayerNorm definition and in
  torch.nn.LayerNorm. gamma and beta are learned per-feature scale and shift.

  Args:
    d_model: size of the last (normalized) dimension.
    eps: small constant added to the variance for numerical stability.
  """
  def __init__(self, d_model, eps=1e-6):
    super().__init__()
    self.gamma = nn.Parameter(torch.ones(d_model))
    self.beta = nn.Parameter(torch.zeros(d_model))
    self.eps = eps

  def forward(self, x):
    mean = x.mean(-1, keepdim=True)  # keepdim so the statistics broadcast back over features
    centered = x - mean
    # Biased variance (divide by N). The previous x.std() divided by N - 1, which
    # differs from the LayerNorm definition and is NaN when the last dim has size 1.
    var = (centered * centered).mean(-1, keepdim=True)
    return self.gamma * centered / torch.sqrt(var + self.eps) + self.beta

# In batch normalization, each feature is normalized across the samples of a batch.
# In layer normalization, each sample is normalized across its own features.

In [None]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal positional encodings to a sequence of embeddings.

  PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
  PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

  The table covers positions 0..max_seq_len-1 and is stored as a buffer, so it
  moves with the module across devices but is not trained.

  Args:
    d_model: embedding dimension (odd values are supported).
    max_seq_len: number of positions to precompute.
  """
  def __init__(self, d_model, max_seq_len):
    super().__init__()
    encoding_matrix = torch.zeros(max_seq_len, d_model)
    position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)  # (max_seq_len, 1)
    # One frequency per (sin, cos) pair of dimensions: 1 / 10000^(2i / d_model),
    # computed in log space for numerical stability.
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model))
    angles = position * div_term  # (max_seq_len, ceil(d_model / 2))
    encoding_matrix[:, 0::2] = torch.sin(angles)  # even dims 0, 2, 4, ...
    # Odd dims 1, 3, 5, ... When d_model is odd there is one fewer odd slot than
    # frequencies, so the last frequency is dropped (it previously crashed here).
    encoding_matrix[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    self.register_buffer("pos_encoding", encoding_matrix)  # fixed, not a trainable parameter

  def forward(self, x):
    """Return x plus the encodings of its first x.size(1) positions.

    x is indexed as (batch, seq_len, d_model); the (seq_len, d_model) slice of
    the table broadcasts over the batch.

    Raises:
      ValueError: if seq_len exceeds the max_seq_len the table was built for.
    """
    seq_len = x.size(1)
    max_len = self.pos_encoding.size(0)
    if seq_len > max_len:
      raise ValueError(f"sequence length {seq_len} exceeds max_seq_len {max_len}")
    return x + self.pos_encoding[:seq_len]


In [None]:
class FeedForward(nn.Module):
  """Position-wise feed-forward network: linear2(ReLU(linear1(x))).

  Expands the last dimension from d_model to d_ff hidden units and projects it
  back to d_model; applied independently at every position.

  Args:
    d_model: input/output feature size.
    d_ff: hidden feature size.
  """
  def __init__(self, d_model, d_ff):
    super().__init__()
    self.linear1 = nn.Linear(d_model, d_ff)  # expansion
    self.linear2 = nn.Linear(d_ff, d_model)  # projection back to d_model

  def forward(self, x):
    hidden = F.relu(self.linear1(x))
    return self.linear2(hidden)

In [None]:
class EncoderLayer(nn.Module):
  """One post-norm encoder layer: self-attention, then a feed-forward network.

  Each sublayer's output passes through dropout, is added back to its input
  (residual connection) and is then normalized: norm(x + dropout(sublayer(x))).

  Args:
    d_model: embedding dimension.
    n_heads: number of attention heads.
    d_ff: hidden size of the feed-forward network.
    dropout_rate: dropout probability applied to each sublayer output.
  """
  def __init__(self, d_model, n_heads, d_ff, dropout_rate):
    super().__init__()
    self.mha = MultiHeadAttention(d_model, n_heads)
    self.ff = FeedForward(d_model, d_ff)
    self.layer1norm = LayerNorm(d_model)
    self.layer2norm = LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout_rate)

  def forward(self, x, src_mask):
    """Apply both sublayers; src_mask is forwarded to self-attention to block padded keys."""
    attended = self.mha(x, x, x, src_mask)
    x = self.layer1norm(x + self.dropout(attended))
    transformed = self.ff(x)
    return self.layer2norm(x + self.dropout(transformed))


In [None]:
class DecoderLayer(nn.Module):
  def __init__(self, d_model, n_heads, d_ff, dropout_rate):
    super().__init__()
    self.masked_attention = MultiHeadAttention(d_model, n_heads)
    self.cross_attention = MultiHeadAttention(d_model, n_heads)
    self.ff = FeedForward(d_model, d_ff)
    self.norm1 = LayerNorm(d_model)
    self.norm2 = LayerNorm(d_model)
    self.norm3 = LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout_rate)

  def forward(self, x, encoder_output, tgt_mask):
    x = self.norm1(x + self.dropout(self.masked_attention(x, x, x, tgt_mask)))# tgt mask is a causal mask for the decoder masked attention
    x = self.norm2(x + self.dropout(self.cross_attention(x, encoder_output, encoder_output))) # src mask is an optional mask applied to encoder output for cross attention
    x = self.norm3(x + self.dropout(self.ff(x)))
    return x

In [None]:
class Transformer(nn.Module):
  def __init__(self, d_model, n_heads, d_ff, n_layers, src_vocab_size, tgt_vocab_size, max_seq_len, dropout_rate, padding_index=0):
    super().__init__()

    self.padding_idx = padding_index

    self.encoder_embedding = InputEmbedding(src_vocab_size, d_model)
    self.decoder_embedding = InputEmbedding(tgt_vocab_size, d_model)

    self.positional_encoding_enc = PositionalEncoding(d_model, max_seq_len)
    self.positional_encoding_dec = PositionalEncoding(d_model, max_seq_len)

    self.encoders = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout_rate) for _ in range(n_layers)])
    self.decoders = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout_rate) for _ in range(n_layers)])

    self.linear = nn.Linear(d_model, tgt_vocab_size)

  def forward(self, src, tgt):

    # Create target padding mask
    tgt_padding_mask = (tgt != self.padding_idx).unsqueeze(1).unsqueeze(2)  # shape: (batch_size, 1, 1, tgt_seq_len)

    # Create target causal mask
    tgt_len = tgt.size(1)
    causal_mask = torch.tril(torch.ones(tgt_len, tgt_len)).to(tgt.device)  # shape: (tgt_len, tgt_len)
    causal_mask = causal_mask.unsqueeze(0).unsqueeze(1)  # shape: (1, 1, tgt_len, tgt_len)

    # Combine both
    tgt_mask = tgt_padding_mask.type(torch.bool) & causal_mask.type(torch.bool)

    # source padding mask
    src_padding_mask = (src != self.padding_idx).unsqueeze(1).unsqueeze(2) # source mask creates a mask that ignores padding tokens
    # unsqueeze 1, 2 since we want to match key matrix

    enc_input = self.encoder_embedding(src)
    dec_input = self.decoder_embedding(tgt)

    enc_input = self.positional_encoding_enc(enc_input) # shape batch size x src seq len x embedding dim
    dec_input = self.positional_encoding_dec(dec_input) # shape batch size x tgt seq len x embedding dim

    for encoder in self.encoders:
      enc_input = encoder(enc_input, src_padding_mask) # each encoder builds off prev encoder, so we reassign to enc input, same for decoder

    for decoder in self.decoders:
      dec_input = decoder(dec_input, enc_input, tgt_mask)

    logits = self.linear(dec_input)
    return logits # Return logits here because in training cross entropy loss expects logits

In [None]:
# Smoke run: one forward pass of a base-sized model on a tiny batch whose rows
# end in padding (token id 0).
model = Transformer(
    d_model=512,
    n_heads=8,
    d_ff=2048,
    n_layers=6,
    src_vocab_size=10000,
    tgt_vocab_size=10000,
    max_seq_len=100,
    dropout_rate=0.1,
    padding_index=0,
)

batch_size, src_seq_len, tgt_seq_len = 2, 10, 9

# Real ids come from [4, 1000), so they never equal the pad id 0 (ids 0-3 are
# presumably reserved for special tokens — confirm against the tokenizer).
src = torch.randint(4, 1000, (batch_size, src_seq_len))
tgt = torch.randint(4, 1000, (batch_size, tgt_seq_len))

# Pad the last two source positions and the final target position.
src[:, src_seq_len - 2:] = 0
tgt[:, tgt_seq_len - 1:] = 0

logits = model(src, tgt)  # (batch_size, tgt_seq_len, tgt_vocab_size)

In [None]:
def test_transformer_forward_pass_doesnt_crash():
    """A 2-layer model maps (4, 32) source and (4, 16) target ids to (4, 16, 1000) logits."""
    batch, src_len, tgt_len, vocab = 4, 32, 16, 1000
    model = Transformer(
        d_model=512, n_heads=8, d_ff=2048, n_layers=2,
        src_vocab_size=vocab, tgt_vocab_size=vocab,
        max_seq_len=50, dropout_rate=0.1,
    )

    # Ids start at 1 so no padding appears.
    src = torch.randint(1, vocab, (batch, src_len))
    tgt = torch.randint(1, vocab, (batch, tgt_len))

    out = model(src, tgt)
    expected = (batch, tgt_len, vocab)
    assert out.shape == expected, f"Unexpected output shape: {out.shape}"

test_transformer_forward_pass_doesnt_crash()

In [None]:
def test_padding_mask_blocks_attention():
    """Padding must not leak into the encoder, and a padded batch must not produce NaNs.

    Encoder states of the real tokens in a padded source row must be identical
    to those of the same tokens encoded without any padding.
    """
    model = Transformer(
        d_model=128,
        n_heads=4,
        d_ff=512,
        n_layers=1,
        src_vocab_size=50,
        tgt_vocab_size=50,
        max_seq_len=10,
        dropout_rate=0.0
    )
    model.eval()

    pad_idx = 0
    model.padding_idx = pad_idx

    src = torch.tensor([
        [1, 2, 3, 4, pad_idx, pad_idx],  # sequence with padding
        [5, 6, 7, 8, 9, 10],             # no padding
    ])
    tgt = torch.tensor([
        [1, 2, 3],
        [4, 5, 6],
    ])

    def encode(tokens):
        # Mirrors the encoder half of Transformer.forward, with an explicit padding mask.
        mask = (tokens != pad_idx).unsqueeze(1).unsqueeze(2)
        hidden = model.positional_encoding_enc(model.encoder_embedding(tokens))
        for layer in model.encoders:
            hidden = layer(hidden, mask)
        return hidden

    with torch.no_grad():
        out = model(src, tgt)
        padded = encode(src[:1])        # 4 real tokens followed by 2 pads
        unpadded = encode(src[:1, :4])  # the same 4 real tokens, no pads

    assert not torch.isnan(out).any(), "Model output contains NaNs. Probably broken masking."
    assert torch.allclose(padded[:, :4], unpadded, atol=1e-5), \
        "Encoder states of real tokens changed when trailing padding was added: padding mask is leaking."

test_padding_mask_blocks_attention()

In [None]:
def test_causal_mask_blocks_future():
    """Changing target tokens at positions >= 3 must leave the logits at positions 0-2 unchanged."""
    model = Transformer(
        d_model=64,
        n_heads=4,
        d_ff=128,
        n_layers=1,
        src_vocab_size=100,
        tgt_vocab_size=100,
        max_seq_len=5,
        dropout_rate=0.0
    )
    model.eval()

    src = torch.randint(1, 100, (1, 5))  # batch size 1; ids >= 1, so no padding
    tgt = torch.randint(1, 100, (1, 5))

    # Replace positions 3 and 4 with different ids that are still non-padding and
    # inside the vocabulary: v in [1, 99] maps to v % 99 + 1, which is in [1, 99] and != v.
    tgt_changed = tgt.clone()
    tgt_changed[:, 3:] = tgt[:, 3:] % 99 + 1

    with torch.no_grad():
        output = model(src, tgt)  # shape: (1, 5, vocab_size)
        output_changed = model(src, tgt_changed)

    assert output.shape == (1, 5, 100), "Unexpected output shape in causal mask test."
    assert torch.allclose(output[:, :3], output_changed[:, :3], atol=1e-5), \
        "Logits at earlier positions changed when later target tokens changed: causal mask is leaking."
    # Sanity check that the edit actually reached the model, so the test above is not vacuous.
    assert not torch.allclose(output[:, 3:], output_changed[:, 3:]), \
        "Changing target tokens had no effect at their own positions; the causal test is vacuous."

test_causal_mask_blocks_future()

In [None]:
def test_mask_shapes_align():
    """The causal mask expands to (1, 1, T, T) and combines with a (B, 1, 1, T) padding mask into (B, 1, T, T)."""
    tgt_len = 6
    tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len))
    expanded = tgt_mask.unsqueeze(0).unsqueeze(1)  # shape should be (1, 1, tgt_len, tgt_len)
    assert expanded.shape == (1, 1, tgt_len, tgt_len), f"Mask shape is wrong: {expanded.shape}"

    # Transformer.forward ANDs a (batch, 1, 1, tgt_len) padding mask with the causal
    # mask; the result must broadcast to one (tgt_len, tgt_len) mask per example.
    batch = 3
    padding = torch.ones(batch, 1, 1, tgt_len, dtype=torch.bool)
    combined = padding & expanded.bool()
    assert combined.shape == (batch, 1, tgt_len, tgt_len), f"Combined mask shape is wrong: {combined.shape}"

test_mask_shapes_align()


In [None]:
# Training configuration. Layer sizes match the "base" model of Vaswani et al.
# (2017), with 5000-token vocabularies.
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

# Seed before building the model so weight initialization and the synthetic
# batch below are identical on every Restart & Run All.
torch.manual_seed(0)

transformer = Transformer(
    d_model=d_model,
    n_heads=num_heads,
    d_ff=d_ff,
    n_layers=num_layers,
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    max_seq_len=max_seq_length,
    dropout_rate=dropout
)

# Synthetic data: ids drawn from [1, vocab_size), so the padding id 0 never appears.
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

In [None]:
loss_fn = nn.CrossEntropyLoss(ignore_index=0)  # positions whose label is the pad id 0 contribute no loss
# Adam with beta2 = 0.98 and eps = 1e-9 as in the paper, but with a constant
# learning rate (no warmup schedule).
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(10):
  optimizer.zero_grad()
  # Teacher forcing: the decoder input is every target token except the last...
  output = transformer(src_data, tgt_data[:, :-1])  # (batch, tgt_len - 1, tgt_vocab_size)
  # ...and the label at position i is token i + 1, i.e. the target shifted LEFT by
  # one, so each position learns to predict the next token.
  target = tgt_data[:, 1:].reshape(-1)  # (batch * (tgt_len - 1),)
  # CrossEntropyLoss expects (N, C) logits and (N,) class ids, so merge the batch
  # and time axes; reshape copies only if the tensor is not contiguous.
  output = output.reshape(-1, output.size(-1))  # (batch * (tgt_len - 1), tgt_vocab_size)
  loss = loss_fn(output, target)
  loss.backward()
  optimizer.step()
  print(f"Epoch {epoch+1}, Loss: {loss.item()}")


KeyboardInterrupt

