<a href="https://colab.research.google.com/github/Armin-Abdollahi/Transformer/blob/main/Attention_Mechanism.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### The **Soft Attention mechanism** lets the model focus on different parts of the input at each output step by taking a softmax-weighted average over all input positions, so every position contributes and the whole operation remains differentiable.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

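class SoftAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(SoftAttention, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Additive (concat) scoring: v^T tanh(W [hidden; encoder_output]); inner size = hidden_dim (arbitrary choice).
        # The tanh matters: with a single linear layer the hidden-state term is identical for every
        # time step and cancels in the softmax, so the weights would not depend on the hidden state.
        self.attention = nn.Linear(self.input_dim + self.hidden_dim, self.hidden_dim)
        self.v = nn.Linear(self.hidden_dim, 1, bias=False)

    def forward(self, encoder_outputs, hidden):
        # encoder_outputs: [seq_len, batch_size, input_dim]; hidden: [batch_size, hidden_dim]
        seq_len = encoder_outputs.size(0)
        # Repeat the hidden state for every time step: [batch_size, seq_len, hidden_dim]
        hidden = hidden.unsqueeze(1).expand(-1, seq_len, -1)
        encoder_outputs = encoder_outputs.transpose(0, 1)  # [batch_size, seq_len, input_dim]
        # Concatenate the hidden state with each encoder output
        merged = torch.cat((hidden, encoder_outputs), dim=2)

        # One score per time step, normalized over the sequence: [batch_size, seq_len]
        attention_scores = self.v(torch.tanh(self.attention(merged))).squeeze(2)
        attention_scores = F.softmax(attention_scores, dim=1)

        # Weighted sum of the encoder outputs: [batch_size, input_dim]
        context_vector = torch.bmm(attention_scores.unsqueeze(1), encoder_outputs).squeeze(1)
        return context_vector, attention_scores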

# Example usage:
# Define the dimensions
input_dim = 128  # Size of the encoder output feature vector
hidden_dim = 256  # Size of the decoder hidden state

# Create the SoftAttention layer
attention_layer = SoftAttention(input_dim, hidden_dim)

# Assume some random encoder outputs and hidden state
encoder_outputs = torch.randn(10, 32, input_dim)  # [seq_len, batch_size, features]
hidden = torch.randn(32, hidden_dim)  # [batch_size, hidden_dim]

# Get the context vector and attention scores
context_vector, attention_scores = attention_layer(encoder_outputs, hidden)
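
In equation form, for one batch element with encoder outputs $h_1, \dots, h_T$ and decoder hidden state $s$, the layer computes a score $e_t = \operatorname{score}(s, h_t)$ for each time step with its scoring layers, normalizes the scores with a softmax, and returns the weighted average of the encoder outputs as the context vector $c$ (the symbols are introduced here for illustration):

$$\alpha_t = \frac{\exp(e_t)}{\sum_{j=1}^{T} \exp(e_j)}, \qquad \sum_{t=1}^{T} \alpha_t = 1, \qquad c = \sum_{t=1}^{T} \alpha_t h_t$$

In the code, the returned `attention_scores` holds the $\alpha_t$ (after the softmax) and `context_vector` holds $c$. Because every $\alpha_t$ is positive, all encoder outputs contribute to $c$, and because $c$ is a smooth function of the scores, the layer can be trained end to end with backpropagation.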

#### **Hard attention mechanisms** select specific parts of the input (in the example below, the single highest-scoring position) and ignore the rest, whereas soft attention combines all parts with varying weights.

In [None]:
import numpy as np

def hard_attention(query, keys, values):
    """
    Implements a simple hard attention mechanism.

    Parameters:
    query (ndarray): The query vector, shape (d,).
    keys (ndarray): The key vectors, shape (n, d).
    values (ndarray): The value vectors, shape (n, d_v).

    Returns:
    ndarray: The value vector at the position with the highest query-key score, shape (d_v,).
    """
    # Calculate the dot product between the query and the keys
    attention_scores = np.dot(query, keys.T)

    # Find the index of the maximum score (hard attention)
    max_index = np.argmax(attention_scores)

    # Select the value vector corresponding to the maximum score
    context_vector = values[max_index]

    return context_vector

# Example usage:
query = np.array([1, 0, 0])
keys = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
values = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

context = hard_attention(query, keys, values)
print("Context Vector:", context)
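
The function above is deterministic: it always picks the highest-scoring position. Hard attention is also commonly stochastic, drawing one position at random with probabilities given by a softmax over the scores. A minimal NumPy sketch of that variant, reusing the same toy inputs (the function name `stochastic_hard_attention` is hypothetical, not part of any library):

```python
import numpy as np

def stochastic_hard_attention(query, keys, values, rng):
    """Hypothetical sketch: sample one position instead of taking the argmax.

    query: (d,), keys: (n, d), values: (n, d_v)
    Returns the sampled value vector and its index.
    """
    scores = keys @ query                      # (n,) dot-product scores
    probs = np.exp(scores - scores.max())      # numerically stable softmax
    probs /= probs.sum()
    index = rng.choice(len(values), p=probs)   # draw one position ~ softmax(scores)
    return values[index], index

rng = np.random.default_rng(0)
query = np.array([1.0, 0.0, 0.0])
keys = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
values = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

context, index = stochastic_hard_attention(query, keys, values, rng)
print("Sampled index:", index, "Context Vector:", context)
```

With these inputs the first position has probability $e/(e+2) \approx 0.58$, so repeated calls usually, but not always, return `[1 2 3]`. As with the argmax version, the sampling step provides no gradient with respect to the scores.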

### Implement a simple **Squeeze-and-Excitation (SE) attention** module

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        # Define the squeeze operation
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Define the excitation operations
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze operation
        y = self.avg_pool(x).view(b, c)
        # Excitation operation
        y = self.fc(y).view(b, c, 1, 1)
        # Scale the input with the attention weights
        return x * y.expand_as(x)

# Example usage
# Define the input tensor with batch size 2, 64 channels, and 32x32 spatial dimensions
input_tensor = torch.randn(2, 64, 32, 32)
# Create the SE attention layer for 64 channels
se_layer = SELayer(channel=64)
# Forward pass to obtain the output with attention applied
output_tensor = se_layer(input_tensor)

print(output_tensor.shape)  # Expected shape: [2, 64, 32, 32]

torch.Size([2, 64, 32, 32])
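
For one batch element $x \in \mathbb{R}^{C \times H \times W}$, the forward pass of `SELayer` can be summarized as follows, where $r$ is `reduction`, $\delta$ is the ReLU, $\sigma$ is the sigmoid, and $W_1 \in \mathbb{R}^{\lfloor C/r \rfloor \times C}$ and $W_2 \in \mathbb{R}^{C \times \lfloor C/r \rfloor}$ are the weights of the two bias-free linear layers in `self.fc` (notation introduced here for illustration):

$$z_c = \frac{1}{HW}\sum_{i=1}^{H}\sum_{j=1}^{W} x_{c,i,j}, \qquad s = \sigma\big(W_2\,\delta(W_1 z)\big), \qquad \tilde{x}_{c,i,j} = s_c \, x_{c,i,j}$$

With `channel=64` and the default `reduction=16`, the bottleneck has $64 // 16 = 4$ units, so the excitation block has $64 \times 4 + 4 \times 64 = 512$ parameters, and each channel is rescaled by a single factor $s_c \in (0, 1)$.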


Below is a high-level overview of how various attention mechanisms can be implemented in PyTorch, which can serve as a starting point for EEG signal processing and other machine learning tasks:

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Soft Attention (scores positions with a learned context vector instead of a decoder hidden state)
class SoftAttention(nn.Module):
    def __init__(self, input_dim, attention_dim):
        super(SoftAttention, self).__init__()
        # Projects each input vector into the attention space
        self.projection = nn.Linear(input_dim, attention_dim)
        # Learned context vector that scores each projected position
        self.context_vector = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, inputs):
        # inputs: [batch_size, seq_len, input_dim]
        # Compute one attention score per position: [batch_size, seq_len]
        attention_scores = torch.tanh(self.projection(inputs))
        attention_scores = self.context_vector(attention_scores).squeeze(2)

        # Apply softmax over the sequence to get the attention distribution
        attention_weights = F.softmax(attention_scores, dim=1)

        # Weighted sum of inputs: [batch_size, input_dim]
        weighted_sum = torch.bmm(attention_weights.unsqueeze(1), inputs).squeeze(1)
        return weighted_sum, attention_weights

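# Hard Attention: selecting positions by sampling (stochastic) or by argmax is not differentiable,
# so it is usually trained with reinforcement learning (e.g. REINFORCE) or approximated with Gumbel-Softmax.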

# Self-Attention (Also known as Intra-Attention)
class SelfAttention(nn.Module):
    def __init__(self, input_dim, attention_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(input_dim, attention_dim)
        self.key = nn.Linear(input_dim, attention_dim)
        self.value = nn.Linear(input_dim, attention_dim)

    def forward(self, inputs):
        Q = self.query(inputs)
        K = self.key(inputs)
        V = self.value(inputs)

        # Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k))
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.size(-1) ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Apply attention weights to values
        output = torch.matmul(attention_weights, V)
        return output, attention_weights

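# Multi-Head Attention (each head has its own query/key/value projections)
class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, num_heads, attention_dim):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.attention_dim = attention_dim
        self.heads = nn.ModuleList([SelfAttention(input_dim, attention_dim) for _ in range(num_heads)])
        # Output projection that mixes the concatenated heads (W^O in the Transformer)
        self.output = nn.Linear(num_heads * attention_dim, input_dim)

    def forward(self, inputs):
        # Run every head on the same inputs, keeping only the outputs (not the weights)
        head_outputs = [head(inputs)[0] for head in self.heads]
        # Concatenate along the feature axis: [batch_size, seq_len, num_heads * attention_dim]
        concatenated = torch.cat(head_outputs, dim=-1)
        # Project back to input_dim: [batch_size, seq_len, input_dim]
        return self.output(concatenated)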

# Cross-Attention: like Self-Attention, but queries come from one input and keys/values from another.
# Causal and Local Attention: Self-Attention with a mask that limits which positions each query can attend to.
# Global Attention is the unmasked case above.



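This code provides a starting point for implementing different attention mechanisms. For EEG signal processing, you will likely need to adjust the input dimensions: EEG epochs are often stored as (trials, channels, time), while the modules above expect `[batch_size, seq_len, features]`, so you must decide which axis acts as the sequence (time steps, or electrodes if you want to attend over channels). You may also want to incorporate domain-specific knowledge, such as the spatial layout of the electrodes, into the attention mechanisms. The exact implementation will depend on the requirements of your task and the rest of your network architecture.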
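
As a sketch of that dimension adjustment, assuming EEG data laid out as (trials, channels, time) with made-up sizes, the time axis can be moved into the sequence position before applying attention. This example uses PyTorch's built-in `nn.MultiheadAttention` rather than the classes above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical EEG batch: 8 trials, 32 electrodes, 256 time samples, stored as (trials, channels, time)
eeg = torch.randn(8, 32, 256)

# Attention modules here expect [batch_size, seq_len, features]: treat time steps as the sequence
# and the 32 electrode values at each time step as the feature vector
x = eeg.permute(0, 2, 1)  # [8, 256, 32]

# Built-in multi-head attention; embed_dim (32) must be divisible by num_heads (4)
mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
output, weights = mha(x, x, x)  # self-attention over time steps

print(output.shape)   # torch.Size([8, 256, 32])
print(weights.shape)  # torch.Size([8, 256, 256]), averaged over the 4 heads by default
```

Leaving the data as `[trials, channels, time]` would instead attend over electrodes, with each channel's 256-sample time series as its feature vector (so `embed_dim` would be 256); which axis to attend over is a modeling choice.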

For Cross-Attention, you would modify the self-attention mechanism to take two different inputs: one supplies the queries, and the other supplies the keys and values (for example, decoder states attending to encoder outputs). Causal Attention restricts each position to attend only to itself and earlier time steps, which is essential for autoregressive tasks such as time-series forecasting, where the model must not see future values. Global vs. Local Attention refers to whether the mechanism considers the entire input sequence (global) or only a window around each position (local); local attention can be implemented by masking positions outside the window or by using convolutions to restrict the receptive field.
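
A minimal sketch of these three variations, assuming plain scaled dot-product attention without the learned query/key/value projections of `SelfAttention` (the helper names `masked_attention`, `causal_mask`, and `local_mask` are hypothetical): cross-attention only changes where the keys and values come from, while causal and local attention differ only in the boolean mask.

```python
import torch
import torch.nn.functional as F

def masked_attention(query, key, value, mask=None):
    """Scaled dot-product attention with an optional boolean mask (True = may attend).

    query: [batch, L_q, d]; key, value: [batch, L_k, d]; mask: [L_q, L_k]
    """
    scores = query @ key.transpose(-2, -1) / query.size(-1) ** 0.5
    if mask is not None:
        # Blocked positions get -inf, so softmax gives them exactly zero weight
        scores = scores.masked_fill(~mask, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ value, weights

def causal_mask(length):
    # Lower-triangular: position i may attend to positions 0..i only
    return torch.tril(torch.ones(length, length, dtype=torch.bool))

def local_mask(length, window):
    # Position i may attend to positions j with |i - j| <= window
    idx = torch.arange(length)
    return (idx[:, None] - idx[None, :]).abs() <= window

torch.manual_seed(0)
x = torch.randn(2, 6, 16)       # e.g. a target/decoder-side sequence
memory = torch.randn(2, 9, 16)  # e.g. a source/encoder-side sequence of a different length

# Cross-attention: queries from x, keys and values from memory (no mask)
cross_out, cross_w = masked_attention(x, memory, memory)

# Causal self-attention: no position can look at later time steps
causal_out, causal_w = masked_attention(x, x, x, causal_mask(6))

# Local self-attention: each position sees only neighbors within one step
local_out, local_w = masked_attention(x, x, x, local_mask(6, window=1))

print(cross_out.shape)  # torch.Size([2, 6, 16])
print(causal_w[0])      # entries above the diagonal are 0
```

PyTorch 2.0 and later also provide `torch.nn.functional.scaled_dot_product_attention`, which accepts a boolean `attn_mask` with the same True-means-attend convention as well as an `is_causal` flag.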

Because hard attention makes a discrete selection (by sampling or by argmax), it is not differentiable, so it is often trained with reinforcement learning techniques (e.g., REINFORCE) or with approximations that allow backpropagation, such as the Gumbel-Softmax trick.
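
A minimal sketch of the Gumbel-Softmax approach using PyTorch's `torch.nn.functional.gumbel_softmax`, with random scores and values standing in for a real model (the tensor sizes and `tau=0.5` are arbitrary assumptions):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, seq_len, dim = 2, 5, 8
# Stand-ins for learnable attention logits and the value vectors to choose from
scores = torch.randn(batch_size, seq_len, requires_grad=True)
values = torch.randn(batch_size, seq_len, dim)

# hard=True gives a one-hot sample in the forward pass (up to floating-point rounding),
# while gradients flow through the relaxed soft sample (straight-through estimator)
one_hot = F.gumbel_softmax(scores, tau=0.5, hard=True, dim=-1)  # [batch_size, seq_len]

# The one-hot weights select exactly one value vector per example
context = torch.bmm(one_hot.unsqueeze(1), values).squeeze(1)    # [batch_size, dim]

# Unlike a plain argmax, this path is differentiable with respect to the scores
context.sum().backward()
print(one_hot)
print(scores.grad.shape)  # torch.Size([2, 5])
```

`tau` controls how close the relaxed sample is to one-hot: lower values reduce the bias of the gradient estimate but increase its variance.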