In [15]:
"""
this notebook contains helper functions for the decoder-only transformer architecture
"""

'\nthis notebook contains helper functions for the decoder-only transformer architecture\n'

In [9]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

In [10]:
def sinusoidal_positional_encoding(max_len, dim_model, dim_op=3, base=10000.0):
    """
    Create the fixed sinusoidal position encodings of Vaswani et al. (2017).

    For position ``p`` and frequency index ``i`` (``0 <= i < dim_model / 2``)::

        PE[p, 2i]     = sin(p / base ** (2i / dim_model))
        PE[p, 2i + 1] = cos(p / base ** (2i / dim_model))

    Args:
        max_len (int): Number of positions to encode (positions 0 .. max_len - 1).
        dim_model (int): Size of each position encoding vector; must be even so the
            dimensions split evenly between sine and cosine.
        dim_op (int): Rank of the returned tensor; must be >= 2. The default of 3
            yields shape (1, max_len, dim_model), which broadcasts against
            (batch, seq_len, dim_model) embeddings.
        base (float): Base of the geometric wavelength progression. Defaults to
            10000.0, the value used in the original Transformer.

    Returns:
        torch.Tensor: float32 tensor of shape
            ``(1,) * (dim_op - 2) + (max_len, dim_model)``.

    Raises:
        ValueError: If ``dim_model`` is odd or ``dim_op`` is smaller than 2.
    """
    if dim_model % 2 != 0:
        raise ValueError("dim_model must be even to split dimensions between sin/cos")
    if dim_op < 2:
        # A rank below 2 cannot hold a (max_len, dim_model) table; without this
        # check such values silently produced a 2-D tensor.
        raise ValueError(f"dim_op must be >= 2, got {dim_op}")

    positions = torch.arange(max_len, dtype=torch.float)
    dimensions = torch.arange(0, dim_model, 2, dtype=torch.float)

    # One inverse wavelength per sin/cos pair of columns.
    angle_rates = 1 / (base ** (dimensions / dim_model))

    # Outer product: angles[p, i] = p * angle_rates[i] -> (max_len, dim_model // 2)
    angles = positions.unsqueeze(1) * angle_rates.unsqueeze(0)

    # Interleave: even columns receive sin, odd columns receive cos.
    encodings = torch.zeros(max_len, dim_model)
    encodings[:, 0::2] = torch.sin(angles)
    encodings[:, 1::2] = torch.cos(angles)

    # Prepend singleton axes so the table broadcasts against batched inputs.
    shape = [1] * (dim_op - 2) + [max_len, dim_model]
    return encodings.view(*shape)

In [11]:
def shift_attention_scores(scores):
    """
    Skew relative-position scores so each entry lines up with its (query, key) pair.

    This is the memory-efficient "skewing" trick of the Music Transformer
    (Huang et al., 2018), in the general form used by Transformer-XL's
    ``_rel_shift`` so that ``seq_len_q`` may differ from ``seq_len_k``.

    Convention: column ``c`` of ``scores`` holds a query's score against the
    embedding for relative distance ``seq_len_k - 1 - c`` (column 0 = farthest,
    last column = distance 0). The result satisfies

        out[..., i, j] = scores[..., i, seq_len_q - 1 - i + j]

    whenever that column index is <= ``seq_len_k - 1``. Equivalently,
    ``out[..., i, j]`` is the score for distance ``(seq_len_k - seq_len_q) + i - j``,
    with queries aligned to the last ``seq_len_q`` keys. Entries for "future"
    keys (negative distance) contain left-over values and must be hidden by a
    causal mask.

    Args:
        scores (torch.Tensor): Relative scores of shape (..., seq_len_q, seq_len_k).

    Returns:
        torch.Tensor: Skewed scores with the same shape as ``scores``.
    """
    *batch_dims, seq_len_q, seq_len_k = scores.shape

    # Prepend one zero column: (..., seq_len_q, seq_len_k + 1).
    padded = F.pad(scores, [1, 0])

    # Reinterpret the same buffer with the two trailing sizes swapped; each
    # original row now starts one element further right than the row above it.
    padded = padded.reshape(*batch_dims, seq_len_k + 1, seq_len_q)

    # Dropping the first row discards the leading pad and leaves exactly
    # seq_len_q * seq_len_k values, read back in the input's shape.
    # `reshape` rather than `view`: the slice is non-contiguous when batch dims exist.
    return padded[..., 1:, :].reshape(*batch_dims, seq_len_q, seq_len_k)

def relative_attention(query, key, value, rel_pos_emb=None, mask=None):
    """
    Scaled dot-product attention with an optional relative-position term.

        scores = (query @ key^T + shift_attention_scores(query @ rel_pos_emb^T)) / sqrt(dim)
        output = softmax(masked scores) @ value

    Args:
        query (torch.Tensor): Query matrix (..., seq_len_q, dim)
        key (torch.Tensor): Key matrix (..., seq_len_k, dim)
        value (torch.Tensor): Value matrix (..., seq_len_k, dim_v)
        rel_pos_emb (torch.Tensor, optional): Relative position embeddings
            (seq_len_k, dim), rows laid out in the column order that
            ``shift_attention_scores`` expects. When None, only content scores are used.
        mask (torch.Tensor, optional): Tensor broadcastable to the score shape
            (..., seq_len_q, seq_len_k); positions where ``mask == 1`` are hidden.

    Returns:
        torch.Tensor: Attention output of shape (..., seq_len_q, dim_v)
    """
    # Content-based attention scores: (..., seq_len_q, seq_len_k)
    attention_scores = torch.matmul(query, key.transpose(-1, -2))

    # Add the relative-position term only when embeddings are provided
    # (no zero tensor is allocated for the "absent" case).
    if rel_pos_emb is not None:
        attention_scores = attention_scores + shift_attention_scores(
            torch.matmul(query, rel_pos_emb.transpose(-1, -2))
        )

    # 1 / sqrt(dim) scaling. Computed with `** -0.5` because `sqrt` was never
    # imported in this notebook (the previous `sqrt(...)` raised NameError).
    attention_scores = attention_scores * query.size(-1) ** -0.5

    # Hide masked positions with the most negative value of the score dtype;
    # a literal -1e9 overflows float16 and makes masked_fill raise.
    if mask is not None:
        attention_scores = attention_scores.masked_fill(
            mask == 1, torch.finfo(attention_scores.dtype).min
        )

    # Compute attention weights and final output
    attention_weights = F.softmax(attention_scores, dim=-1)
    return torch.matmul(attention_weights, value)

In [12]:
class MultiHeadRelativeAttention(nn.Module):
    """
    Multi-head self-attention with learned relative position embeddings
    (the relative attention of Music Transformer, Huang et al. 2018).

    For every head, the score between query position ``i`` and key position
    ``j <= i`` is

        (q_i . k_j + q_i . E[min(i - j, max_rel_dist - 1)]) / sqrt(head_dim)

    where ``E`` is that head's ``head_dim``-wide slice of ``rel_pos_embedding``.
    The relative term is computed for all pairs at once with the skewing trick
    (``_skew``). Entries with ``j > i`` hold left-over values and are only
    meaningful when hidden by a causal ``mask``.
    """
    def __init__(self, d_model: int, num_heads: int, max_rel_dist: int, dropout: float = 0.1):
        """
        Args:
            d_model: Model dimension (must be divisible by num_heads)
            num_heads: Number of attention heads
            max_rel_dist: Number of distinct relative distances with their own
                embedding (distances 0 .. max_rel_dist - 1); larger distances
                reuse the embedding of max_rel_dist - 1.
            dropout: Dropout rate for attention weights

        Raises:
            ValueError: If d_model is not divisible by num_heads, or max_rel_dist < 1.
        """
        super().__init__()

        if d_model % num_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
        if max_rel_dist < 1:
            raise ValueError(f"max_rel_dist must be >= 1, got {max_rel_dist}")

        # Core dimensions
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = self.head_dim ** -0.5  # Pre-compute attention scale

        # Linear projections
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)  # Combined Q,K,V projection
        self.output = nn.Linear(d_model, d_model)

        # Relative position embedding: row d embeds distance d (split across heads).
        self.rel_pos_embedding = nn.Embedding(max_rel_dist, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Splits input tensor into multiple heads.
        shape: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, head_dim)
        """
        batch, seq_len, _ = x.shape
        x = x.view(batch, seq_len, self.num_heads, self.head_dim)
        return x.transpose(1, 2)

    def _get_rel_pos(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Relative position embeddings for a length-``seq_len`` sequence.

        Returns a (seq_len, d_model) tensor whose row ``c`` embeds distance
        ``seq_len - 1 - c`` clamped to ``max_rel_dist - 1`` (row 0 = farthest,
        last row = distance 0), which is the column order ``_skew`` expects.
        """
        max_dist = self.rel_pos_embedding.num_embeddings - 1
        distances = torch.arange(seq_len - 1, -1, -1, device=device).clamp(max=max_dist)
        return self.rel_pos_embedding(distances)

    @staticmethod
    def _skew(rel_scores: torch.Tensor) -> torch.Tensor:
        """
        Align relative scores with (query, key) pairs via the Music Transformer skew.

        Input (..., L, L) with column c = distance L - 1 - c; output (..., L, L)
        with ``out[..., i, j]`` = score for distance ``i - j`` when ``j <= i``.
        """
        *batch_dims, len_q, len_k = rel_scores.shape
        padded = F.pad(rel_scores, [1, 0])  # (..., len_q, len_k + 1)
        skewed = padded.reshape(*batch_dims, len_k + 1, len_q)[..., 1:, :]
        # `reshape`, not `view`: the row slice is non-contiguous with batch dims.
        return skewed.reshape(*batch_dims, len_q, len_k)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute multi-head relative attention.

        Args:
            x: Input tensor of shape (batch, seq_len, d_model)
            mask: Optional tensor broadcastable to (batch, num_heads, seq_len, seq_len);
                positions where ``mask == 1`` are hidden. It should include a causal
                component, since relative scores for future keys are not meaningful.

        Returns:
            Output tensor of shape (batch, seq_len, d_model)
        """
        batch_size, seq_len, _ = x.shape

        # Project Q, K, V together for efficiency
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(self._split_heads, qkv)  # (batch, num_heads, seq_len, head_dim)

        # (seq_len, d_model) -> (num_heads, seq_len, head_dim): one table per head,
        # broadcast over the batch in the matmul below.
        rel_pos = self._get_rel_pos(seq_len, x.device)
        rel_pos = self._split_heads(rel_pos.unsqueeze(0)).squeeze(0)

        # Content and (skewed) relative-position scores: (batch, num_heads, seq_len, seq_len)
        content_scores = torch.matmul(q, k.transpose(-2, -1))
        position_scores = self._skew(torch.matmul(q, rel_pos.transpose(-2, -1)))

        attention_scores = (content_scores + position_scores) * self.scale

        # Fill with the dtype's most negative value; a literal -1e9 overflows float16.
        if mask is not None:
            attention_scores = attention_scores.masked_fill(
                mask == 1, torch.finfo(attention_scores.dtype).min
            )

        # Get attention weights and apply to values
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        attention = torch.matmul(attention_weights, v)

        # Combine heads and project output
        attention = attention.transpose(1, 2).contiguous()
        output = attention.view(batch_size, seq_len, self.d_model)
        return self.output(output)

In [13]:
import torch
import torch.nn as nn
from typing import Optional

class FeedForward(nn.Module):
    """
    Position-wise feed-forward sublayer: Linear -> GELU -> Dropout -> Linear.

    The same two-layer MLP (``dim_model -> dim_ff -> dim_model``) is applied
    independently at every position of the sequence.
    """
    def __init__(self, dim_model: int, dim_ff: int, dropout: float = 0.1):
        """
        Args:
            dim_model: Width of the input and output features
            dim_ff: Width of the hidden layer
            dropout: Dropout probability applied after the GELU activation
        """
        super().__init__()
        expand = nn.Linear(dim_model, dim_ff)
        contract = nn.Linear(dim_ff, dim_model)
        self.network = nn.Sequential(expand, nn.GELU(), nn.Dropout(dropout), contract)

        # Re-initialise both projections in network order: Kaiming-normal weights
        # (default arguments of nn.init.kaiming_normal_) and zero biases.
        for projection in (expand, contract):
            nn.init.kaiming_normal_(projection.weight)
            nn.init.zeros_(projection.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor whose last dimension is dim_model, e.g. (batch, seq_len, dim_model)
        Returns:
            Tensor of the same shape as ``x``
        """
        hidden = self.network(x)
        return hidden

class DecoderBlock(nn.Module):
    """
    Pre-LN Transformer decoder block with relative-position self-attention.

    Data flow (self-attention only, no cross-attention, no final norm inside the block):

        hidden = x + Dropout(SelfAttention(LayerNorm1(x), mask))
        out    = hidden + Dropout(FeedForward(LayerNorm2(hidden)))
    """
    def __init__(
        self,
        dim_model: int,
        num_heads: int,
        dim_ff: int,
        max_rel_dist: int,
        dropout: float = 0.1,
        eps: float = 1e-6
    ):
        """
        Args:
            dim_model: Hidden width of the block
            num_heads: Number of attention heads
            dim_ff: Hidden width of the feed-forward sublayer
            max_rel_dist: Number of relative distances given their own embedding
            dropout: Dropout probability (attention weights, FFN hidden, sublayer outputs)
            eps: LayerNorm epsilon
        """
        super().__init__()

        # Sublayers. Registration order is kept fixed so parameter order and
        # state_dict keys stay stable.
        attn_kwargs = dict(
            d_model=dim_model,
            num_heads=num_heads,
            max_rel_dist=max_rel_dist,
            dropout=dropout,
        )
        self.self_attention = MultiHeadRelativeAttention(**attn_kwargs)
        self.feed_forward = FeedForward(dim_model=dim_model, dim_ff=dim_ff, dropout=dropout)

        # One LayerNorm in front of each sublayer; one dropout shared by both residual branches.
        self.norm1 = nn.LayerNorm(dim_model, eps=eps)
        self.norm2 = nn.LayerNorm(dim_model, eps=eps)
        self.dropout = nn.Dropout(dropout)

        # Kept for introspection by callers.
        self.dim_model = dim_model
        self.num_heads = num_heads

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Run the block once.

        Args:
            x: Input tensor of shape (batch, seq_len, dim_model)
            mask: Optional attention mask forwarded unchanged to self-attention

        Returns:
            Tensor of shape (batch, seq_len, dim_model)
        """
        attn_out = self.self_attention(self.norm1(x), mask=mask)
        hidden = x + self.dropout(attn_out)

        ffn_out = self.feed_forward(self.norm2(hidden))
        return hidden + self.dropout(ffn_out)

In [14]:
def create_padding_mask(sequence: torch.Tensor, pad_token: int = 0) -> torch.Tensor:
    """
    Build a key-padding mask marking every position that holds ``pad_token``.

    Args:
        sequence: Token ids of shape (batch_size, seq_len)
        pad_token: Id used for padding (default 0)

    Returns:
        float32 tensor of shape (batch_size, 1, 1, seq_len): 1.0 at padding
        positions, 0.0 elsewhere. The two inserted singleton axes let it
        broadcast against (batch, heads, seq_len_q, seq_len_k) scores.
    """
    is_pad = sequence.eq(pad_token)
    return is_pad[:, None, None].to(torch.float)

def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Build a boolean mask that hides future keys from each query.

    Args:
        seq_len: Length of sequence
        device: Device to allocate the mask on (default device when None)

    Returns:
        bool tensor of shape (1, 1, seq_len, seq_len), True strictly above the
        diagonal (key index > query index), i.e. at positions to be masked.
    """
    # Allocate directly on the target device instead of building and moving it.
    future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(diagonal=1)
    return future[None, None]

def create_combined_mask(
    sequence: torch.Tensor,
    pad_token: int = 0,
    causal: bool = True
) -> torch.Tensor:
    """
    Merge the key-padding mask with (optionally) the causal mask.

    Args:
        sequence: Token ids of shape (batch_size, seq_len)
        pad_token: Id used for padding
        causal: Whether to also hide future positions

    Returns:
        Float mask where 1.0 marks a position to hide:
          * causal=True  -> shape (batch_size, 1, seq_len, seq_len), padding OR future;
          * causal=False -> the padding mask alone, shape (batch_size, 1, 1, seq_len).
    """
    padding_mask = create_padding_mask(sequence, pad_token)
    if causal:
        # torch.maximum type-promotes the bool causal mask to the padding mask's
        # float dtype, so the result is 1.0 wherever either mask is set.
        future_mask = create_causal_mask(sequence.size(-1), sequence.device)
        return torch.maximum(padding_mask, future_mask)
    return padding_mask

def apply_mask(
    attention_scores: torch.Tensor,
    mask: torch.Tensor,
    fill_value: float = -1e9
) -> torch.Tensor:
    """
    Return a copy of ``attention_scores`` with masked positions overwritten.

    Args:
        attention_scores: Raw attention scores
        mask: Tensor broadcastable to ``attention_scores``; entries equal to 1
            (or True) mark positions to hide
        fill_value: Value written at hidden positions

    Returns:
        New tensor with ``fill_value`` wherever ``mask == 1``
    """
    hidden_positions = mask == 1
    masked_scores = attention_scores.masked_fill(hidden_positions, fill_value)
    return masked_scores