In [None]:
import torch
import torch.nn as nn

class FIRE(nn.Module):
    def __init__(self, num_heads=12, mlp_width=32, init_c=0.1,
                 init_L=512., eps=1e-6):
        """
        FIRE attention bias module.

        Args:
            num_heads: number of attention heads.
            mlp_width: Width of MLP.
            init_c: initial value of log transformation parameter.
            init_L: initial value of thresholding parameter.
            eps: small constant for numerical stability.
        """
        super(FIRE, self).__init__()

        # Define the MLP layers
        self.mlp = nn.Sequential(
            nn.Linear(1, mlp_width),
            nn.ReLU(),
            nn.Linear(mlp_width, num_heads)
        )

        # Initialize c (log transformation parameter)
        self.c = nn.Parameter(torch.tensor(init_c), requires_grad=False)

        # Initialize L (threshold)
        self.init_L = nn.Parameter(torch.tensor(init_L), requires_grad=False)
        # Learn a multiplier to L
        self.L_multiplier = nn.Parameter(torch.tensor(1.0))

        self.eps = eps

    def forward(self, x: torch.Tensor):
        """
        Compute FIRE attention bias.

        Args:
            x: input sequence, shape [bsz, num_heads, seq_len, hidden_dim]

        Returns:
            attention bias, shape [1, num_heads, seq_len, seq_len]
        """
        seq_length = x.size(2)
        positions = torch.arange(seq_length, dtype=torch.float, device=x.device)
        rel_distance = positions[:, None] - positions[None, :]

        # Thresholding the normalizer
        threshold = torch.abs(self.L_multiplier * self.init_L)
        pos_normalizer = torch.max(positions, threshold)
        pos_normalizer = pos_normalizer[:, None]

        # Amplifying differences among local positions with log transform
        rel_distance = torch.log(torch.abs(self.c * rel_distance) + 1)
        pos_normalizer = torch.log(torch.abs(self.c * pos_normalizer) + 1) + self.eps

        # Progressive interpolation
        normalized_distance = rel_distance / pos_normalizer
        fire_bias = self.mlp(normalized_distance.unsqueeze(-1))
        fire_bias = fire_bias.unsqueeze(0).permute(0, 3, 1, 2)
        return fire_bias

In [None]:
class AttentionWithFIRE(nn.Module):
    def __init__(self, dim_model, num_heads):
        super(AttentionWithFIRE, self).__init__()
        self.attention = nn.MultiheadAttention(dim_model, num_heads)
        self.positional_encoding = FIRE(dim_model)

    def forward(self, queries, keys, values):
        # Get query and key positions
        query_pos = torch.arange(queries.size(0))
        key_pos = torch.arange(keys.size(0))
        
        # Compute FIRE positional encoding
        pos_encoding = self.positional_encoding(query_pos, key_pos)
        
        # Add positional encoding to attention scores
        attn_output, attn_weights = self.attention(queries + pos_encoding, keys, values)
        return attn_output, attn_weights

***

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# FIRE module as provided
class FIRE(nn.Module):
    def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512., eps=1e-6):
        super(FIRE, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, mlp_width),
            nn.ReLU(),
            nn.Linear(mlp_width, num_heads)
        )
        self.c = nn.Parameter(torch.tensor(init_c))
        self.init_L = nn.Parameter(torch.tensor(init_L), requires_grad=False)
        self.L_multiplier = nn.Parameter(torch.tensor(1.0))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        seq_length = x.size(2)
        positions = torch.arange(seq_length, dtype=torch.float, device=x.device)
        rel_distance = positions[:, None] - positions[None, :]

        # Thresholding the normalizer
        threshold = torch.abs(self.L_multiplier * self.init_L)
        pos_normalizer = torch.max(positions, threshold)
        pos_normalizer = pos_normalizer[:, None]

        # Amplifying differences among local positions with log transform
        rel_distance = torch.log(torch.abs(self.c * rel_distance) + 1)
        pos_normalizer = torch.log(torch.abs(self.c * pos_normalizer) + 1) + self.eps

        # Progressive interpolation
        normalized_distance = rel_distance / pos_normalizer
        fire_bias = self.mlp(normalized_distance.unsqueeze(-1))
        fire_bias = fire_bias.unsqueeze(0).permute(0, 3, 1, 2)
        return fire_bias

# Multihead Attention with FIRE integration
class MultiheadAttentionWithFIRE(nn.Module):
    def __init__(self, embed_dim, num_heads, fire_params=None):
        super(MultiheadAttentionWithFIRE, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        # Define projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Define FIRE module
        self.fire = FIRE(**fire_params) if fire_params else FIRE(num_heads=num_heads)

    def forward(self, query, key, value, mask=None):
        bsz, seq_len, embed_dim = query.size()

        # Linear projections
        Q = self.q_proj(query).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(key).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute scaled dot-product attention
        scores = torch.einsum("bhqd, bhkd -> bhqk", Q, K) / self.head_dim ** 0.5

        # Apply FIRE bias
        fire_bias = self.fire(Q)
        scores = scores + fire_bias

        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        # Attention weights
        attn_weights = F.softmax(scores, dim=-1)

        # Compute attention output
        output = torch.einsum("bhqk, bhvd -> bhqd", attn_weights, V)
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, embed_dim)

        # Final linear projection
        return self.out_proj(output)

In [6]:
# Example parameters
embed_dim = 6
num_heads = 2
fire_params = {'num_heads': num_heads, 'mlp_width': 32}

# Instantiate the model
attention_layer = MultiheadAttentionWithFIRE(embed_dim, num_heads, fire_params)

# Dummy input
query = torch.rand(2, 5, embed_dim)  # [batch_size, seq_len, embed_dim]
key = torch.rand(2, 5, embed_dim)
value = torch.rand(2, 5, embed_dim)

# Forward pass
output = attention_layer(query, key, value)
print(output.shape)  # Should be [16, 10, 64]


torch.Size([2, 5, 6])


In [7]:
output

tensor([[[ 0.7272, -2.0209, -1.0937,  0.3331, -0.3517, -0.3415],
         [ 0.7272, -2.0209, -1.0937,  0.3331, -0.3517, -0.3415],
         [ 0.7272, -2.0209, -1.0937,  0.3331, -0.3517, -0.3415],
         [ 0.7272, -2.0209, -1.0937,  0.3331, -0.3517, -0.3415],
         [ 0.7272, -2.0209, -1.0937,  0.3331, -0.3517, -0.3415]],

        [[ 0.7759, -1.8854, -0.9876,  0.2124, -0.4803, -0.3230],
         [ 0.7759, -1.8854, -0.9876,  0.2124, -0.4803, -0.3230],
         [ 0.7759, -1.8854, -0.9876,  0.2124, -0.4803, -0.3230],
         [ 0.7759, -1.8854, -0.9876,  0.2124, -0.4803, -0.3230],
         [ 0.7759, -1.8854, -0.9876,  0.2124, -0.4803, -0.3230]]],
       grad_fn=<ViewBackward0>)

***

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature

    def forward(self, query, key, value, mask=None):
        # Dot product of query and key (transpose last two dimensions of key)
        attn = torch.matmul(query, key.transpose(-2, -1)) / self.temperature
        
        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(attn)  # Expanding mask for all heads.
            attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(attn, dim=-1)
        output = torch.matmul(attn, value)
        return output, attn

class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
        self.layer_norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(n_head * d_v, d_model)

        self.dropout = nn.Dropout(dropout)
        
        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.w_qs.weight, mean=0, std=0.02)
        nn.init.normal_(self.w_ks.weight, mean=0, std=0.02)
        nn.init.normal_(self.w_vs.weight, mean=0, std=0.02)
        nn.init.xavier_normal_(self.fc.weight)

    def forward(self, q, k, v, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        # Pass through the pre-attention projection: query, key, value are all split into multiple heads
        qs = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        ks = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        vs = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        qs, ks, vs = qs.transpose(1, 2), ks.transpose(1, 2), vs.transpose(1, 2)  # [sz_b, n_head, len_q, d_k]

        if mask is not None:
            mask = mask.unsqueeze(1)   # [sz_b, 1, 1, len_k]

        # Apply attention on all the projected vectors in batch
        outputs, attn = self.attention(qs, ks, vs, mask=mask)

        # Concatenate heads and project back to original size
        outputs = outputs.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        outputs = self.dropout(self.fc(outputs))
        outputs += q
        outputs = self.layer_norm(outputs)

        return outputs, attn

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_width, dropout_rate=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(
            n_head=num_heads, d_model=embed_dim, 
            d_k=embed_dim // num_heads, d_v=embed_dim // num_heads, 
            dropout=dropout_rate)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout_rate)
        )
        self.fire = FIRE(num_heads=num_heads, mlp_width=mlp_width)

    def forward(self, src):
        attn_bias = self.fire(src)
        
        # Ensure the mask is appropriate for multi-head attention.
        attn_bias = attn_bias.repeat(src.size(0), 1, 1, 1)  # Make sure the mask has the batch dimension where necessary
        attn_bias = (attn_bias <= 0).to(torch.float32)  # Creating the actual mask

        attn_output, _ = self.attention(src, src, src, mask=attn_bias)
        attn_output = self.norm1(attn_output)

        ff_output = self.feed_forward(attn_output)
        output = attn_output + ff_output
        output = self.norm2(output)

        return output

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_width, dropout_rate=0.1):
        super(TransformerBlock, self).__init__()
        
        # Manually implemented multi-head attention
        self.attention = MultiHeadAttention(n_head=num_heads, d_model=embed_dim, d_k=embed_dim // num_heads, d_v=embed_dim // num_heads, dropout=dropout_rate)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Feed-forward network as part of the Transformer block
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout_rate)
        )

        # FIRE encoder for position encoding
        self.fire = FIRE(num_heads=num_heads, mlp_width=mlp_width)

    def forward(self, src):
        # Apply FIRE positional encoding as an additive bias in attention
        attn_bias = self.fire(src)
        attn_output, _ = self.attention(src, src, src, mask=attn_bias)
        attn_output = self.norm1(attn_output)  # Layer norm after addition and attention

        ff_output = self.feed_forward(attn_output)
        output = attn_output + ff_output  # Residual connection
        output = self.norm2(output)  # Layer norm after addition

        return output

In [9]:
# Define the Transformer model parameters
embed_dim = 128  # Embedding dimension
num_heads = 8    # Number of attention heads
mlp_width = 256  # Width of the MLP in the FIRE module
seq_length = 10  # Length of the input sequence
batch_size = 2   # Number of sequences in a batch

# Instantiate the TransformerBlock
transformer_block = TransformerBlock(embed_dim=embed_dim, num_heads=num_heads, mlp_width=mlp_width)

# Create a sample input tensor
# Assume input tensor shape [batch_size, seq_length, embed_dim]
# Random tensor mimicking a batch of embedded sequences
input_tensor = torch.randn(batch_size, seq_length, embed_dim)

# Pass the input through the Transformer model
output = transformer_block(input_tensor)

# Display the output
print("Output Tensor Shape:", output.shape)
print("Output Tensor:", output)

RuntimeError: expand(torch.FloatTensor{[1, 1, 1, 8, 128, 128]}, size=[2, 8, 10, 10]): the number of sizes provided (4) must be greater or equal to the number of dimensions in the tensor (6)