# Transformer Mechanisms

### Positional Encoding
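Positional encoding provides positional information for each token in a sequence. It is required because the attention mechanism on its own treats the input as an unordered set of tokens and has no notion of word order. A positional encoding vector is calculated from the token's position index (not from the token embedding). Even embedding dimensions use sine functions and odd dimensions use cosine functions, at frequencies that vary across dimensions, so every position receives a unique vector. This vector is added to the token embedding, and the sum is what the first attention layer receives.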
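Here is a minimal sketch of the sinusoidal formula, assuming PyTorch and deliberately tiny sizes. Each even dimension 2i of position `pos` holds sin(pos / 10000^(2i/d_model)), and the odd dimension after it holds the matching cosine. The helper name `sinusoidal_encoding` is illustrative. The computation mirrors the `PositionalEncoder` class in the code cell below.

```python
import math
import torch

def sinusoidal_encoding(max_length: int, d_model: int) -> torch.Tensor:
    """Return a (max_length, d_model) matrix of sinusoidal positional encodings."""
    position = torch.arange(max_length, dtype=torch.float).unsqueeze(1)  # (max_length, 1)
    # 10000^(-2i / d_model) for each even dimension index 2i
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model))
    pe = torch.zeros(max_length, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions use sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions use cosine
    return pe

pe = sinusoidal_encoding(max_length=4, d_model=4)
print(pe[0])  # position 0: sin(0) = 0, cos(0) = 1 -> tensor([0., 1., 0., 1.])

# The input to the first attention layer is the token embedding plus its position's encoding
token_embeddings = torch.randn(4, 4)  # 4 tokens with d_model = 4 (random stand-ins)
encoder_input = token_embeddings + pe
```

With d_model = 4 the second sine/cosine pair oscillates 100 times more slowly than the first (its frequency is 0.01 instead of 1). Mixing fast and slow frequencies in this way keeps each position's vector distinct.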

### Attention Mechanism
#### Self-Attention
Self-attention provides the model with information on the relationships between the tokens in a sequence. The Query, Key and Value matrices are linear projections of the token embeddings, each with its own learned weights. The dot product of the Query and Key matrices gives an attention score for every pair of tokens. In the original Transformer these scores are divided by the square root of the key dimension so that the softmax does not saturate. The softmax function then turns each row of scores into weights that sum to 1, giving the more relevant tokens a higher share of the attention. These weights are applied to the Value matrix to produce the updated token embeddings.
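The steps above can be sketched for a single unbatched sequence. This illustrative example uses random embeddings and freshly initialised projections, and the names `w_q`, `w_k` and `w_v` are hypothetical. It includes the division by the square root of the key dimension used in the original Transformer.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model = 3, 8
x = torch.randn(seq_len, d_model)  # token embeddings (random stand-ins)

# Separate learned projections for queries, keys and values
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))
q, k, v = w_q(x), w_k(x), w_v(x)

# Score for every pair of tokens, scaled by sqrt(d_k) (d_k = d_model for a single head)
scores = q @ k.transpose(-2, -1) / math.sqrt(d_model)  # (seq_len, seq_len)
weights = F.softmax(scores, dim=-1)                    # each row sums to 1
updated = weights @ v                                  # (seq_len, d_model)

print(weights.sum(dim=-1))  # ones, up to floating-point rounding
```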
#### Multi-Head Attention
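The self-attention above is applied several times in parallel. Each head uses its own projections into a smaller subspace of size d_model / num_heads, so different heads can learn different aspects of the relationships in the sequence. The outputs of all heads are concatenated, and a final linear layer projects the result back to d_model to give the updated token embeddings.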
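The following shape-level sketch splits the inputs into heads, runs attention per head, and concatenates the results again. It uses random stand-ins for the projected queries, keys and values, and the helper `split` is hypothetical. Each head works on a `head_dim = d_model // num_heads` slice.

```python
import torch
import torch.nn as nn

batch_size, seq_len, d_model, num_heads = 2, 5, 8, 2
head_dim = d_model // num_heads
q = torch.randn(batch_size, seq_len, d_model)  # stand-in for projected queries
k = torch.randn(batch_size, seq_len, d_model)  # stand-in for projected keys
v = torch.randn(batch_size, seq_len, d_model)  # stand-in for projected values

def split(t):
    # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
    return t.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

# Every head attends independently over its own head_dim-sized slice
scores = split(q) @ split(k).transpose(-2, -1) / head_dim ** 0.5  # (batch, heads, seq, seq)
per_head = torch.softmax(scores, dim=-1) @ split(v)                # (batch, heads, seq, head_dim)

# Concatenate the heads back to d_model and apply the output projection
concat = per_head.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
output = nn.Linear(d_model, d_model)(concat)
print(scores.shape, output.shape)  # torch.Size([2, 2, 5, 5]) torch.Size([2, 5, 8])
```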
#### Feed Forward Layer
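The feed forward sub-layer processes the output of the attention stage one token position at a time. It consists of two linear layers with a ReLU activation between them, which makes the representations non-linear. The inner (hidden) dimension d_ff is typically larger than d_model (2048 versus 512 in the original Transformer), giving the layer more capacity to capture complex patterns.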
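This small sketch, with illustrative sizes, shows that the sub-layer is position-wise. `nn.Linear` acts only on the last dimension, so each token is transformed independently of its neighbours.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 8, 32  # illustrative sizes; d_ff is larger than d_model
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)
out = ffn(x)                    # applied to every position at once

# Transforming one token on its own matches its slice of the full-sequence output
single = ffn(x[:, 2, :])
print(torch.allclose(out[:, 2, :], single, atol=1e-6))  # True
```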

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Positional Encoding using the nn.Module subclass
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_length):
        super(PositionalEncoder, self).__init__()
        self.d_model = d_model
        self.max_length = max_length
        # Create a positional encoding matrix up to the specified max_length
        pe = torch.zeros(max_length, d_model)
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        # Scale position indices and calculate position encodings into the matrix
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Add a batch dimension: (1, max_length, d_model)
        pe = pe.unsqueeze(0)
        # Register the matrix as a buffer: saved and moved with the model, but not trained
        self.register_buffer('pe', pe)

    # Add the positional encodings to the embeddings tensor
    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x


# Multi-Head Attention Mechanism
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        # Set the number of attention heads
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        # Set up the linear transformations for Q, K, V
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        # Set the layer for the final concatenated output
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        # Split the embeddings across the attention heads:
        # (batch, seq_len, d_model) -> (batch * num_heads, seq_len, head_dim)
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.head_dim)

    def compute_attention(self, query, key, mask=None):
        # Compute Q-K dot-product scores, scaled by sqrt(head_dim):
        # (batch * num_heads, q_len, head_dim) @ (batch * num_heads, head_dim, k_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # mask must broadcast to (batch * num_heads, q_len, k_len); positions where mask == 0 are hidden
            scores = scores.masked_fill(mask == 0, float("-1e20"))
        # Use softmax to normalize attention scores into attention weights
        attention_weights = F.softmax(scores, dim=-1)
        return attention_weights

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)
        # Calculate the attention weights
        attention_weights = self.compute_attention(query, key, mask)
        # Multiply attention weights by values and concatenate the heads' outputs
        output = torch.matmul(attention_weights, value)
        output = output.view(batch_size, self.num_heads, -1, self.head_dim).permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.d_model)
        # Linearly project the outputs
        return self.output_linear(output)

# Feed Forward layer
class FeedForwardSubLayer(nn.Module):
    # Specify the two linear layers' input and output sizes
    def __init__(self, d_model, d_ff):
        super(FeedForwardSubLayer, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        # Apply a ReLU layer for non-linearity
        self.relu = nn.ReLU()

    # Apply a forward pass
    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))
```

# Transformer Encoder Architecture
#### Encoder Layer
The encoder layer applies the Multi-Head Attention mechanism followed by the Feed Forward sub-layer. Each sub-layer's output is added back to its input (a residual connection) and passed through layer normalization, which keeps the outputs on a similar scale and stabilizes training. Dropout regularizes training by reducing overfitting of the model to the training data. When sequences of different lengths are batched together, a padding mask prevents the attention mechanism from attending to the padding tokens.
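Here is a minimal sketch of a padding mask, assuming token id 0 is padding (a convention chosen for the example, not fixed by the model). The mask is True for real tokens and is broadcast over every query row of the score matrix, so padded key positions receive zero attention weight. The notebook's `MultiHeadAttention` folds the heads into the batch dimension. There, the same mask would also need to be repeated `num_heads` times along dimension 0.

```python
import torch
import torch.nn.functional as F

pad_id = 0  # assumed padding token id
# Two sequences padded to length 4; the second has only 2 real tokens
tokens = torch.tensor([[5, 7, 9, 3],
                       [4, 8, 0, 0]])

# True where the key position holds a real token; shape (batch, 1, seq_len) so it
# broadcasts over every query row of the (batch, seq_len, seq_len) score matrix
key_mask = (tokens != pad_id).unsqueeze(1)

scores = torch.zeros(2, 4, 4)                               # placeholder attention scores
scores = scores.masked_fill(key_mask == 0, float("-1e20"))  # same convention as MultiHeadAttention
weights = F.softmax(scores, dim=-1)
print(weights[1, 0])  # tensor([0.5000, 0.5000, 0.0000, 0.0000])
```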
#### Encoder Transformer Head
The transformer head processes the encoded inputs and produces a task-specific output. In our example, we define:
- A classification head, for tasks such as extractive question-answering, text classification and sentiment analysis.
- A regression head, for tasks such as assessing language complexity or estimating text readability.
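Both heads can be sketched on a stand-in encoder output. Like `ClassifierHead`, this sketch assumes the first token's vector summarizes the whole sequence. Mean pooling over the tokens is a common alternative. The sizes and the single readability target are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, seq_len, d_model = 2, 6, 16
num_classes, output_dim = 3, 1  # illustrative: 3 sentiment classes, 1 readability score
encoder_output = torch.randn(batch_size, seq_len, d_model)  # stand-in for the encoder's output

# Both heads read the representation of the first token of each sequence
first_token = encoder_output[:, 0, :]                                        # (batch, d_model)
class_log_probs = F.log_softmax(nn.Linear(d_model, num_classes)(first_token), dim=-1)
readability = nn.Linear(d_model, output_dim)(first_token)                     # unbounded real value

print(class_log_probs.shape, readability.shape)  # torch.Size([2, 3]) torch.Size([2, 1])
```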

```python
# Encoder layer class
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Self-attention sub-layer, then residual connection and layer normalization
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Feed-forward sub-layer, then residual connection and layer normalization
        ff_output = self.feed_forward(x)
        return self.norm2(x + self.dropout(ff_output))

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length):
        super(TransformerEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoder(d_model, max_sequence_length)
        # Define a stack of multiple encoder layers
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

    # Apply the forward pass
    def forward(self, x, mask):
        x = self.embedding(x)
        x = self.positional_encoding(x)
        # Perform for all encoder layers
        for layer in self.layers:
            x = layer(x, mask)
        return x

# Transformer classification head
class ClassifierHead(nn.Module):
    def __init__(self, d_model, num_classes):
        super(ClassifierHead, self).__init__()
        # Add linear layer for multiple-class classification
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # Classify from the representation of the first token in each sequence
        logits = self.fc(x[:, 0, :])
        # Convert the raw outputs (logits) into log class probabilities
        return F.log_softmax(logits, dim=-1)

# Transformer regression head
class RegressionHead(nn.Module):
    def __init__(self, d_model, output_dim):
        super(RegressionHead, self).__init__()
        # Add linear layer that predicts output_dim continuous values
        self.fc = nn.Linear(d_model, output_dim)

    def forward(self, x):
        # Predict from the first token's representation, as in the classification head
        return self.fc(x[:, 0, :])
```

## Training and Testing the Encoder Architecture

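```python
import torch
import torch.nn as nn
import torch.optim as optim

# Illustrative hyperparameters, kept small so the example runs quickly
vocab_size = 1000
batch_size = 8
sequence_length = 16
max_sequence_length = 64
d_model = 64
num_heads = 4
num_layers = 2
d_ff = 256
dropout = 0.1
num_classes = 3
epoch_num = 5

# Random training sequences (ids start at 1, so there is no padding token 0) and random class labels
train_data = torch.randint(1, vocab_size, (batch_size, sequence_length))
train_labels = torch.randint(0, num_classes, (batch_size,))
# Padding mask over the key positions, repeated per head to match MultiHeadAttention's
# (batch * num_heads, q_len, k_len) score layout
train_mask = (train_data != 0).repeat_interleave(num_heads, dim=0).unsqueeze(1)

# Instantiate the encoder transformer's body and classification head
encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length)
classifier = ClassifierHead(d_model, num_classes)
optimizer = optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=1e-3)
# ClassifierHead returns log-probabilities, so use the negative log-likelihood loss
criterion = nn.NLLLoss()

# Train our model
encoder.train()
classifier.train()
for epoch in range(epoch_num):
    # Zero the gradients before each pass
    optimizer.zero_grad()
    output = classifier(encoder(train_data, train_mask))
    # Compute the loss and update the parameters
    loss = criterion(output, train_labels)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}: Loss = {loss.item():.3f}")

# Test the trained model on a new batch of random sequences (token 0 treated as padding)
input_sequence = torch.randint(0, vocab_size, (batch_size, sequence_length))
mask = (input_sequence != 0).repeat_interleave(num_heads, dim=0).unsqueeze(1)

# Complete the forward pass in evaluation mode, without tracking gradients
encoder.eval()
classifier.eval()
with torch.no_grad():
    output = encoder(input_sequence, mask)
    classification = classifier(output)
print(classification)
```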

# Transformer Decoder Architecture
#### Decoder Layer - Masked Multi-Head Self Attention
The decoder layer is similar to the encoder layer, but its self-attention is masked (Masked Multi-Head Self-Attention). The mask hides every entry above the diagonal of the attention-score matrix. Each token can therefore attend only to itself and the tokens before it, never to future tokens. This lets the model learn to predict the next word in the sequence, one token at a time.
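The causal mask can be built with `torch.tril`. This sketch assumes the same convention as the notebook's `MultiHeadAttention`, where positions with mask value 0 are hidden. The result is equivalent to the `1 - torch.triu(..., diagonal=1)` construction used in the decoder test.

```python
import torch
import torch.nn.functional as F

seq_len = 4
# 1 on and below the diagonal (visible), 0 above it (future positions)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
print(causal_mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])

scores = torch.randn(seq_len, seq_len)  # placeholder attention scores
weights = F.softmax(scores.masked_fill(causal_mask == 0, -1e20), dim=-1)
# Row i spreads its attention over positions 0..i only; the upper triangle is zero
```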

#### Decoder Transformer Head
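A linear layer maps each decoder output to one score per word in the vocabulary. A softmax (log-softmax in the code below) then converts these scores into a probability distribution, from which the most likely next token is predicted.
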
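Here is a sketch of this head on a stand-in decoder output, with illustrative sizes. Picking the highest-scoring word (greedy decoding) is only one option, and sampling-based strategies are also common.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, d_model, vocab_size = 2, 5, 16, 100  # illustrative sizes
decoder_hidden = torch.randn(batch_size, seq_len, d_model)  # stand-in for the decoder layers' output

head = nn.Linear(d_model, vocab_size)                    # one score per vocabulary word
log_probs = F.log_softmax(head(decoder_hidden), dim=-1)  # (batch, seq_len, vocab_size)

# The distribution at the last position predicts the next token;
# greedy decoding simply takes the most likely word
next_token = log_probs[:, -1, :].argmax(dim=-1)  # (batch,)
print(next_token.shape)  # torch.Size([2])
```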
# Transformer Encoder-Decoder Architecture
#### Output Embeddings (Decoder Inputs)
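During training, the decoder takes the target sequences as its input, e.g. the reference summary in summarisation tasks or the following words in text generation tasks. The target is shifted one position to the right, so that each position learns to predict the next token. At inference time, the output sequence starts empty (typically containing only a start token) and is extended as each new token is generated.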
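Both modes are sketched below with hypothetical token ids (`bos_id` for the start token). A hypothetical `predict_next` function stands in for a trained decoder plus greedy head, so the generated ids only illustrate how the sequence grows.

```python
import torch

bos_id = 1  # hypothetical start-of-sequence token id

# Training (teacher forcing): a reference target sequence, e.g. a summary
target = torch.tensor([[bos_id, 12, 47, 8, 2]])
decoder_input = target[:, :-1]  # what the decoder reads:  [[1, 12, 47, 8]]
labels = target[:, 1:]          # what it must predict:    [[12, 47, 8, 2]]

def predict_next(sequence):
    # Hypothetical stand-in for a trained decoder followed by a greedy head
    return sequence[:, -1:] + 1

# Inference: start from the start token alone and append one prediction per step
generated = torch.tensor([[bos_id]])
for _ in range(3):
    generated = torch.cat([generated, predict_next(generated)], dim=1)
print(generated)  # tensor([[1, 2, 3, 4]])
```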
#### Cross Attention Mechanism
Cross-attention takes place in the decoder layer, after the masked self-attention step. It takes two inputs: the final hidden states from the encoder and the hidden states passed through the decoder. The encoder outputs (y, `encoder_output` in the code) are provided as the key and value arguments, allowing the decoder to consider the processed input sequence. The decoder's hidden states (x) provide the query, guiding the generation of the target sequence.
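Here is a shape-level sketch of cross-attention using random stand-ins, with the learned projections omitted for brevity. The queries come from the decoder and the keys come from the encoder. The weight matrix therefore has one row per target position and one column per source position.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, src_len, tgt_len, d_model = 2, 7, 4, 8
encoder_output = torch.randn(batch_size, src_len, d_model)  # stand-in for the encoder's final hidden states
x = torch.randn(batch_size, tgt_len, d_model)               # stand-in for the decoder's hidden states

# Queries from the decoder; keys and values from the encoder output
query, key, value = x, encoder_output, encoder_output
scores = query @ key.transpose(-2, -1) / math.sqrt(d_model)  # (batch, tgt_len, src_len)
weights = F.softmax(scores, dim=-1)
context = weights @ value                                    # (batch, tgt_len, d_model)
print(weights.shape, context.shape)  # torch.Size([2, 4, 7]) torch.Size([2, 4, 8])
```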
#### Encoder-Decoder Transformer Head
As in the decoder-only model, a linear layer followed by a softmax is applied to the decoder output, giving probabilities for the next word. Other activations may be used instead of softmax depending on the task.
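During training, these probabilities are compared with the next-token labels. This sketch uses random stand-in data and PyTorch's `nn.NLLLoss`, which expects log-probabilities such as those returned by `F.log_softmax`. Applying `F.cross_entropy` to the raw logits gives the same value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, vocab_size = 2, 4, 10
logits = torch.randn(batch_size, seq_len, vocab_size)        # stand-in for the linear layer's output
log_probs = F.log_softmax(logits, dim=-1)                    # as returned by TransformerDecoder
labels = torch.randint(0, vocab_size, (batch_size, seq_len))  # next-token targets

# NLLLoss expects (N, C) log-probabilities and (N,) class indices, so flatten batch and time
loss = nn.NLLLoss()(log_probs.view(-1, vocab_size), labels.view(-1))
# Equivalent to cross-entropy on the raw logits
same = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
print(torch.allclose(loss, same))  # True
```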

```python
# Decoder layer class
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        # Initialize the causal (masked) self-attention and cross-attention
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)  # Unused in decoder-only models
        self.feed_forward = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, causal_mask, encoder_output=None, cross_mask=None):
        # Causal self-attention over the decoder's own tokens
        self_attn_output = self.self_attn(x, x, x, causal_mask)
        x = self.norm1(x + self.dropout(self_attn_output))
        # Cross-attention: queries from the decoder, keys and values from the encoder output.
        # Skipped when no encoder output is given (decoder-only use)
        if encoder_output is not None:
            cross_attn_output = self.cross_attn(x, encoder_output, encoder_output, cross_mask)
            x = self.norm2(x + self.dropout(cross_attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoder(d_model, max_sequence_length)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        # Apply the transformer head for next-word prediction inside this class
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, self_mask, encoder_output=None, cross_mask=None):
        x = self.embedding(x)
        x = self.positional_encoding(x)
        # encoder_output and cross_mask are only needed in the encoder-decoder setting
        for layer in self.layers:
            x = layer(x, self_mask, encoder_output, cross_mask)
        # Apply the forward pass through the model head: log-probabilities over the vocabulary
        x = self.fc(x)
        return F.log_softmax(x, dim=-1)
```

## Testing the Decoder

```python
input_sequence = torch.randint(0, vocab_size, (batch_size, sequence_length))

# Create a causal attention mask: 1 on and below the diagonal (visible), 0 above it (future tokens)
self_attention_mask = (1 - torch.triu(torch.ones(1, sequence_length, sequence_length), diagonal=1)).bool()

# Instantiate the decoder transformer (decoder-only: no encoder output is passed)
decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length=sequence_length)

output = decoder(input_sequence, self_attention_mask)
print(output.shape)
print(output)
```

## Testing the Encoder-Decoder Transformer

```python
# Create a batch of random source and target sequences (token 0 treated as padding)
input_sequence = torch.randint(0, vocab_size, (batch_size, sequence_length))
target_sequence = torch.randint(0, vocab_size, (batch_size, sequence_length))
# Padding mask over the source tokens, repeated per head: (batch * num_heads, 1, src_len)
padding_mask = (input_sequence != 0).repeat_interleave(num_heads, dim=0).unsqueeze(1)
# Causal mask: 1 on and below the diagonal (visible), 0 above it (future tokens)
causal_mask = torch.tril(torch.ones(sequence_length, sequence_length))

# Instantiate the two transformer bodies
encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length=sequence_length)
decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length=sequence_length)

# Pass the necessary masks as arguments to the encoder and the decoder
encoder_output = encoder(input_sequence, padding_mask)
decoder_output = decoder(target_sequence, causal_mask, encoder_output, padding_mask)
print("Batch's output shape: ", decoder_output.shape)
```