In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Encoder

## Attention

In [None]:
def single_attention(query, key, value, mask=None):
  """
  Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.

  Args:
    query: tensor (..., query_len, d). Leading dims are batch-like and must broadcast
      against those of key/value (e.g. (batch, seq, d) or (batch, heads, seq, d)).
    key: tensor (..., key_len, d).
    value: tensor (..., key_len, d_v).
    mask: optional tensor broadcastable to (..., query_len, key_len). Positions where
      mask == 0 get a score of -inf and therefore zero attention weight.
      NOTE: a query row whose mask is entirely 0 produces NaN (softmax over all -inf).

  Returns:
    Tensor (..., query_len, d_v).
  """
  key_dim = query.size(-1)
  # matmul + transpose(-2, -1) gives the same result as bmm for 3-D inputs, and also
  # accepts extra leading dims (e.g. a heads axis) with broadcasting.
  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(key_dim)

  if mask is not None:
    # Hidden positions (e.g. future tokens in the decoder) get -inf so softmax ignores them.
    scores = scores.masked_fill(mask == 0, float("-inf"))

  attention = F.softmax(scores, dim=-1)

  return torch.matmul(attention, value)

In [None]:
class AttentionHead(nn.Module):
  """
  A single self-attention head: bias-free linear query/key/value projections from
  embed_dim down to head_dim, followed by scaled dot-product attention.

  head_dim is normally embed_dim // num_heads, i.e. embed_dim is a multiple of
  head_dim, so total computation stays constant as the number of heads changes.
  Example: BERT-base uses 12 heads, each of dimension 768 / 12 = 64.
  """

  def __init__(self, embed_dim, head_dim, mask=None):
    super().__init__()

    self.query = nn.Linear(embed_dim, head_dim, bias=False)
    self.key = nn.Linear(embed_dim, head_dim, bias=False)
    self.value = nn.Linear(embed_dim, head_dim, bias=False)

    # Default mask used on every forward call unless one is passed explicitly.
    self.mask = mask

  def forward(self, hidden_state, mask=None):
    """
    Apply self-attention to hidden_state.

    Args:
      hidden_state: input tensor whose last dim is embed_dim.
      mask: optional mask for this call only; when None, the mask given at
        construction (possibly None) is used.

    Returns:
      Attention output whose last dim is head_dim.
    """
    active_mask = self.mask if mask is None else mask
    return single_attention(self.query(hidden_state), self.key(hidden_state), self.value(hidden_state), active_mask)

In [None]:
class MultiHeadAttention(nn.Module):
  """
  Concatenate outputs of single attention heads to produce final output

  Config is based on transformers.AutoConfig (reads hidden_size and
  num_attention_heads).

  Args:
    config: model config.
    mask: optional mask shared by every head (passed to each AttentionHead).

  Raises:
    ValueError: if hidden_size is not divisible by num_attention_heads.
  """

  def __init__(self, config, mask=None):
    super().__init__()

    embed_dim = config.hidden_size
    num_heads = config.num_attention_heads
    # With floor division, a non-divisible config would make the concatenated heads
    # narrower than embed_dim, and output_linear would fail at forward time.
    if embed_dim % num_heads != 0:
      raise ValueError(
          f"hidden_size ({embed_dim}) must be divisible by num_attention_heads ({num_heads})"
      )
    head_dim = embed_dim // num_heads

    self.heads = nn.ModuleList(
        [AttentionHead(embed_dim, head_dim, mask) for _ in range(num_heads)]
    )

    # Mixes the concatenated head outputs back into a hidden_size representation.
    self.output_linear = nn.Linear(embed_dim, embed_dim)

  def forward(self, hidden_state):
    x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)  # last dim: num_heads * head_dim == embed_dim
    x = self.output_linear(x)
    return x

## Feed-forward

In [None]:
class FeedForwardNetwork(nn.Module):
  """
  Position-wise feed-forward block: expand hidden_size -> intermediate_size,
  apply GELU, project back to hidden_size, then apply dropout.

  Reads config.hidden_size, config.intermediate_size and config.hidden_dropout_prob.
  """

  def __init__(self, config):
    super().__init__()

    hidden, inner = config.hidden_size, config.intermediate_size

    # Projections are created in the same order as before (expand, then contract).
    self.linear_1 = nn.Linear(hidden, inner)
    self.linear_2 = nn.Linear(inner, hidden)
    self.gelu = nn.GELU()  # GELU is the activation used in BERT-style encoders
    self.dropout = nn.Dropout(config.hidden_dropout_prob)

  def forward(self, x):
    expanded = self.gelu(self.linear_1(x))
    return self.dropout(self.linear_2(expanded))

## Layer Normalization

Layer normalization: normalize each input's feature vector (per token) to zero mean and unit variance, then apply a learned scale and shift

Skip connections: pass a tensor to the next layer of the model without processing and add it to the sub-layer's output

In [None]:
# Use pre-layer normalization

# Layer Norm --> Multi-head Attention --> Layer Norm --> Feed forward

class EncoderLayer(nn.Module):
  """
  Transformer encoder layer with pre-layer normalization:
    x = x + Attention(LayerNorm_1(x))
    x = x + FeedForward(LayerNorm_2(x))
  """

  def __init__(self, config):
    # nn.Module.__init__ must run before any submodule is assigned; otherwise PyTorch
    # raises AttributeError("cannot assign module before Module.__init__() call").
    super().__init__()

    self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
    self.layer_norm_2 = nn.LayerNorm(config.hidden_size)

    self.attention = MultiHeadAttention(config)
    self.feed_forward = FeedForwardNetwork(config)

  def forward(self, x):
    hidden_state = self.layer_norm_1(x)  # First layer normalization (pre-LN)
    x = x + self.attention(hidden_state)  # Attention w/ skip connection
    x = x + self.feed_forward(self.layer_norm_2(x))  # Second layer normalization, feed forward w/ skip connection

    return x

## Positional Embeddings

Incorporate positional information: attention is just a weighted sum over tokens, so on its own it cannot tell which position a token came from

In [None]:
# Augment token embeddings w/ position index (can be learned by attention layer & feed forward network)

class PositionalEmbedding(nn.Module):
  """
  Token embeddings plus learned absolute position embeddings, followed by
  LayerNorm and dropout.

  Reads config.vocab_size, config.hidden_size, config.max_position_embeddings and
  config.hidden_dropout_prob. Sequences longer than max_position_embeddings cannot
  be embedded (the position lookup goes out of range).
  """

  def __init__(self, config):
    super().__init__()

    self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
    self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

    self.layer_norm = nn.LayerNorm(config.hidden_size)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)

  def forward(self, input_ids):
    """
    Args:
      input_ids: integer token ids; assumed (batch_size, seq_length).

    Returns:
      Embeddings of shape (*input_ids.shape, hidden_size).
    """
    seq_length = input_ids.size(-1)
    # Create positions on the inputs' device; a CPU-only arange breaks once the model runs on GPU.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0)  # (1, seq_length)

    token_embeddings = self.token_embeddings(input_ids)  # (batch_size, seq_length, hidden_size)
    position_embeddings = self.position_embeddings(position_ids)  # (1, seq_length, hidden_size), broadcast over batch

    embeddings = token_embeddings + position_embeddings

    embeddings = self.layer_norm(embeddings)
    embeddings = self.dropout(embeddings)

    return embeddings

In [None]:
class Encoder(nn.Module):
  """
  Transformer encoder: embeds token ids (tokens + positions) and passes the result
  through config.num_hidden_layers EncoderLayer blocks in order.

  Returns one hidden state per input token.
  """

  def __init__(self, config):
    super().__init__()

    self.embeddings = PositionalEmbedding(config)
    stack = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
    self.layers = nn.ModuleList(stack)

  def forward(self, x):
    hidden_state = self.embeddings(x)
    for encoder_layer in self.layers:
      hidden_state = encoder_layer(hidden_state)
    return hidden_state

## Text Classification

In [None]:
class SequenceClassification(nn.Module):
  """
  Sequence classifier on top of the Encoder.

  Uses the final hidden state of the first token of each sequence (the [CLS] token
  when inputs are tokenized BERT-style; assumed, not checked here), applies dropout,
  and maps it to config.num_labels logits.

  config.num_labels must be defined before constructing this module.
  """

  def __init__(self, config):
    super().__init__()

    self.encoder = Encoder(config)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)
    self.classifier = nn.Linear(config.hidden_size, config.num_labels)

  def forward(self, x):
    hidden_states = self.encoder(x)
    first_token_state = hidden_states[:, 0, :]  # hidden state at position 0 of every sequence
    return self.classifier(self.dropout(first_token_state))

# Decoder

## Attention

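The decoder's self-attention must be masked so that each position can only attend to earlier positions; otherwise it could simply copy the next token instead of learning to predict it. This is why the attention code above accepts a `mask` argument.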

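The decoder also needs an encoder-decoder (cross-)attention layer. It uses decoder states as queries and encoder outputs as keys and values, so the model can learn how tokens from two different sequences relate.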

In [None]:
# Example causal mask for a short sequence.
# ex_seq_len stands in for a real tokenized length (e.g. inputs.input_ids.size(-1));
# a literal keeps this cell runnable on a fresh kernel where `inputs` is not defined.
ex_seq_len = 5
ex_mask = torch.tril(torch.ones(ex_seq_len, ex_seq_len)).unsqueeze(0)  # (1, seq, seq) lower-triangular: row i has ones in columns <= i
ex_mask

In [None]:
class EncoderDecoderAttention(nn.Module):
  """
  Encoder-decoder (cross) attention layer.

  Queries come from the decoder hidden state; keys and values come from the encoder
  output, so every decoder position can attend to every encoder position. No mask
  is applied.
  """

  def __init__(self, config):
    super().__init__()
    # MultiHeadAttention provides the per-head query/key/value projections and the
    # final output projection; forward() wires them up for two different inputs.
    self.attention = MultiHeadAttention(config)

  def forward(self, decoder_hidden_state, encoder_hidden_state):
    """
    Args:
      decoder_hidden_state: decoder activations, projected into queries.
      encoder_hidden_state: encoder output, projected into keys and values.

    Returns:
      Tensor with one row per decoder position and last dim config.hidden_size.
    """
    # MultiHeadAttention.forward projects a single input for queries, keys and values,
    # so the per-head projections are applied here with separate sources.
    head_outputs = [
        single_attention(
            head.query(decoder_hidden_state),
            head.key(encoder_hidden_state),
            head.value(encoder_hidden_state),
        )
        for head in self.attention.heads
    ]
    return self.attention.output_linear(torch.cat(head_outputs, dim=-1))

## Layer Normalization

In [None]:
# Use pre-layer normalization

# Layer Norm --> Masked multi-head attention --> Layer Norm --> Encoder-decoder attention --> Layer Norm --> Feed forward

class DecoderLayer(nn.Module):
  """
  Decoder layer with masked self-attention, encoder-decoder attention, and feed-forward network.

  Uses pre-layer normalization: each sub-layer receives a LayerNorm'd input and its
  output is added back to the residual stream.

  Args:
    config: model config.
    mask: optional fixed self-attention mask; a mask passed to forward() takes precedence.
  """

  def __init__(self, config, mask=None):
    # Must run before assigning submodules (nn.Module bookkeeping).
    super().__init__()

    self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
    self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
    self.layer_norm_3 = nn.LayerNorm(config.hidden_size)

    self.attention = MultiHeadAttention(config, mask)
    self.encoder_decoder_attention = EncoderDecoderAttention(config)

    self.feed_forward = FeedForwardNetwork(config)

  def forward(self, x, encoder_output, mask=None):
    """
    Args:
      x: decoder hidden states.
      encoder_output: encoder hidden states attended to by the cross-attention.
      mask: optional per-call self-attention mask (e.g. a causal mask built for the
        current sequence length).
    """
    hidden_state = self.layer_norm_1(x)  # First layer normalization
    x = x + self._self_attention(hidden_state, mask)  # Masked attention w/ skip connection

    hidden_state = self.layer_norm_2(x)  # Second layer normalization
    x = x + self.encoder_decoder_attention(hidden_state, encoder_output)  # Encoder-decoder attention w/ skip connection

    x = x + self.feed_forward(self.layer_norm_3(x))  # Third layer normalization, feed forward w/ skip connection

    return x

  def _self_attention(self, hidden_state, mask):
    """Self-attention over hidden_state; a non-None `mask` overrides the construction-time mask."""
    if mask is None:
      return self.attention(hidden_state)
    # MultiHeadAttention.forward only uses the mask fixed at construction, so a per-call
    # mask is applied by running each head's projections through single_attention directly.
    head_outputs = [
        single_attention(head.query(hidden_state), head.key(hidden_state), head.value(hidden_state), mask)
        for head in self.attention.heads
    ]
    return self.attention.output_linear(torch.cat(head_outputs, dim=-1))

## Decoder Stack (Positional Embeddings + Decoder Layers)

In [None]:
class Decoder(nn.Module):
  """
  Transformer decoder: embeds target token ids and runs them through
  config.num_hidden_layers DecoderLayer blocks. Each layer receives the encoder
  output and a causal mask sized to the current sequence.
  """

  def __init__(self, config):
    super().__init__()

    self.embeddings = PositionalEmbedding(config)

    self.layers = nn.ModuleList(
        DecoderLayer(config) for _ in range(config.num_hidden_layers)
    )

  def forward(self, x, encoder_output):
    hidden_state = self.embeddings(x)

    # The mask is rebuilt on every call because the target length can vary between
    # batches; it is created on the same device as the activations.
    causal_mask = self.create_mask(hidden_state.size(1), hidden_state.device)

    for decoder_layer in self.layers:
      hidden_state = decoder_layer(hidden_state, encoder_output, causal_mask)

    return hidden_state  # one hidden state per target position

  def create_mask(self, seq_length, device):
    """Return a (1, seq_length, seq_length) lower-triangular mask of ones: position i may attend to positions <= i."""
    ones = torch.ones(seq_length, seq_length, device=device)
    return torch.tril(ones).unsqueeze(0)