# SparseAttention
https://github.com/kyegomez/SparseAttention


blocksparse_attention_impl
-----------------------------------------------------------------------
1. Transpose for Strided Attention -> strided_transpose(): If attn_mode == 'strided', the function first applies strided_transpose to tensors q, k, and v. This reshapes and transposes the tensors into blocks and local contexts to prepare them for a strided attention pattern.

2. Compute Attention Weights: Calculate the attention weights w by multiplying q with the transposed k, then apply softmax normalization. This determines how much each token (query) attends to every other token (key).

3. Compute Attended Output: Compute the attended output a by multiplying the attention weights w with the value tensor v.

4. Reverse Transpose (if Strided Attention): If attn_mode == 'strided', the output a is transposed back so that tokens return to their original order, undoing step 1.


strided_transpose
-----------------------------------------------
Purpose:
The purpose of strided_transpose is to rearrange the tensor x to facilitate a specific attention pattern
known as strided attention. In strided attention, each token attends only to tokens at a fixed stride instead of
every token, which reduces the computational cost while still capturing long-range dependencies.

1. bT_ctx is computed as n_ctx // local_attn_ctx. This is the number of blocks of size local_attn_ctx that fit into the
total context length (n_ctx).

2. x = torch.reshape(x, [n, bT_ctx, local_attn_ctx, embd]): Reshapes the tensor x to split the sequence (t) into
bT_ctx blocks of local context size local_attn_ctx.

3. x = torch.transpose(x, 0, 2, 1, 3): Reorders the axes to (0, 2, 1, 3), swapping the block axis with the
within-block axis so that tokens sharing the same offset inside their blocks become adjacent. Note that
torch.transpose swaps exactly two dimensions; in PyTorch a four-axis reorder like this is written x.permute(0, 2, 1, 3).

4. x = torch.reshape(x, [n, t, embd]): Reshapes the tensor back to its original shape [n, t, embd] after the
transposition.
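
The four steps above can be sketched as a small standalone function. This is a minimal illustration, not the repository's code: the name strided_transpose_sketch is hypothetical, and it uses permute for the four-axis reorder.

```python
import torch

def strided_transpose_sketch(x, n_ctx, local_attn_ctx):
    """Reorder tokens so that positions with the same offset inside their block become adjacent."""
    n, t, embd = x.shape
    bT_ctx = n_ctx // local_attn_ctx                  # number of blocks
    assert bT_ctx * local_attn_ctx == t, "sequence length must be a multiple of local_attn_ctx"
    x = x.reshape(n, bT_ctx, local_attn_ctx, embd)    # split the sequence into blocks
    x = x.permute(0, 2, 1, 3)                         # swap block axis and within-block axis
    return x.reshape(n, t, embd)                      # flatten back to [n, t, embd]

# Token ids 0..7 stored as a single feature, with a block size of 4
x = torch.arange(8.0).reshape(1, 8, 1)
order = strided_transpose_sketch(x, n_ctx=8, local_attn_ctx=4).flatten().tolist()
print(order)  # tokens 0 and 4, 1 and 5, ... end up next to each other
```

After the reorder, tokens that are local_attn_ctx positions apart in the original sequence sit side by side, which is what lets a strided pattern be handled with contiguous blocks.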

Usage in SparseAttention (attn_mode == 'strided'):
In the context of SparseAttention with attn_mode == 'strided', this function is used to preprocess each input
tensor (q, k, and v) before performing the actual attention computation. Strided attention patterns are
beneficial when dealing with long sequences because they reduce the computational cost of attending to all tokens,
while still allowing tokens to attend to each other at varying distances.


get_attn_mask
---------------------------------------------------------------------------------------------
n: Number of time steps or sequence length, determining the size of the mask.
attn_mode: Specifies the type of attention mask to generate ('all', 'local', 'strided').


1. 'all' Mode: Creates a lower triangular matrix where all positions on and below the main diagonal
	are set to 1 (torch.tril(torch.ones([n, n]))), i.e. a standard causal mask.
2. 'local' Mode: Limits the attention range to local_attn_ctx (the bandwidth) by setting all positions more than
	ctx steps below the diagonal to 0.
3. 'strided' Mode: Implements a strided attention pattern:
	1. Constructs tensors q and k using torch.arange and torch.transpose to create matrices of indices.
	2. Checks conditions (c1, c2) to determine which elements should be attended to based on a stride
		condition (local_attn_ctx).
	3. Combines conditions to form a binary mask b where 1 indicates allowed attention and 0 indicates
		forbidden attention.
4. Reshaping: Reshapes the mask tensor b to [1, 1, n, n] to align with the dimensions expected by the attention mechanism.
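
The three modes can be illustrated with a short sketch. It follows the description above rather than the repository's code: the function name sketch_attn_mask is hypothetical, and the bandwidth convention for 'local' mode (ctx = local_attn_ctx - 1) is an assumption.

```python
import torch

def sketch_attn_mask(n, attn_mode, local_attn_ctx=None):
    """Return a [1, 1, n, n] mask with 1 = allowed attention and 0 = forbidden attention."""
    q = torch.arange(n).reshape(n, 1)    # query index, one per row
    k = q.transpose(0, 1)                # key index, one per column
    causal = q >= k                      # a query never attends to a later key
    if attn_mode == 'all':
        b = causal
    elif attn_mode == 'local':
        ctx = min(n - 1, local_attn_ctx - 1)     # assumed bandwidth convention
        b = causal & ((q - k) <= ctx)
    elif attn_mode == 'strided':
        c1 = causal
        c2 = (q - k) % local_attn_ctx == 0       # keys a multiple of the stride behind the query
        b = c1 & c2
    else:
        raise ValueError(f"unknown attn_mode: {attn_mode}")
    return b.float().reshape(1, 1, n, n)

print(sketch_attn_mask(4, 'strided', local_attn_ctx=2)[0, 0])
```

Each row of the printed matrix is one query position; the columns marked 1 are the keys it may attend to.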

attention_impl
-------------------------------------------------------------------------------------------
1. Split Heads: Splits the query, key, and value tensors into multiple heads using the split_heads
function: (batch, pixel, state) -> (batch, pixel, head, head_state), where head_state = x.size()[-1] // head.

2. Get Mask: Generates the attention mask (mask) using get_attn_mask.

3. Compute Attention Weights (w):
	1. Calculates the attention scores by computing the dot product of queries and keys
		(torch.matmul(q, k.transpose(-2, -1))).
	2. Scales the scores by scale_amount (typically 1 / sqrt(d_k) where d_k is the dimension of queries or keys).
	3. Applies the attention mask (w = w * mask + -1e9 * (1 - mask)), which replaces forbidden scores with a large
		negative value so that softmax assigns them a weight of effectively zero.
	4. Applies softmax to compute normalized attention weights (F.softmax(w, dim=-1)).
4. Compute Attended Output (a):
	1. Uses the attention weights w to compute the weighted sum of values (torch.matmul(w, v)).
	2. Merge Heads: Merges the output of multiple heads back into a single tensor using the merge_heads
		function (not provided).
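
Steps 3 and 4 can be sketched as follows. This is a minimal illustration under the assumption that q, k, and v are already split into heads with shape (batch, heads, n, head_dim); masked_attention_sketch is a hypothetical name, and head splitting and merging are left out.

```python
import torch
import torch.nn.functional as F

def masked_attention_sketch(q, k, v, mask):
    """q, k, v: (batch, heads, n, head_dim); mask: (1, 1, n, n) with 1 = allowed, 0 = forbidden."""
    scale_amount = q.shape[-1] ** -0.5                        # 1 / sqrt(d_k)
    w = torch.matmul(q, k.transpose(-2, -1)) * scale_amount   # raw attention scores
    w = w * mask + -1e9 * (1 - mask)                          # push forbidden scores to -1e9
    w = F.softmax(w, dim=-1)                                  # normalized attention weights
    return torch.matmul(w, v), w                              # weighted sum of values

torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
mask = torch.tril(torch.ones(4, 4)).reshape(1, 1, 4, 4)       # causal ('all' mode) mask
a, w = masked_attention_sketch(q, k, v, mask)
print(a.shape, w[0, 0])
```

With the causal mask, the first query can only see the first key, so its output equals the first value vector.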

# FlashAttention

https://github.com/kyegomez/FlashAttention20



### `FlashAttentionFunction` Class (Custom Autograd Function)

### Purpose

The `forward` method computes the forward pass of the FlashAttention mechanism. It performs attention calculation between query (`q`), key (`k`), and value (`v`) tensors, handles chunking of these tensors to manage memory efficiently, applies attention masks, supports causal attention, and computes output tensors (`o`) along with necessary intermediate values.

### Breakdown of the `forward` Method


- **Parameters:**
  - `q`, `k`, `v`: Query, key, and value tensors.
  - `mask`: Attention mask.
  - `causal`: Boolean flag indicating if the attention is causal (only attends to previous positions).
  - `q_bucket_size`, `k_bucket_size`: Sizes of query and key buckets for chunking.

1. **Initialization and Setup:**
   - **Device and Data Type Handling:** Determines the device of the input query tensor `q` (`device = q.device`) and computes a very large negative value (`max_neg_value`) that is later used to fill masked positions.
   - **Dimension and Scaling:** Computes the length difference between `k` and `q` (`qk_len_diff`) to adjust for variable lengths in attention matrices. Sets up tensors (`o`, `all_row_sums`, `all_row_maxes`) for output, row-wise sums, and maximum values.

2. **Mask Handling:**
   - **Existence Check and Rearrangement:** Checks if `mask` exists and reshapes it if it has 2 dimensions (`mask.ndim == 2`). If `mask` does not exist, initializes `col_masks` and repeats them based on `num_row_tiles` and `num_col_tiles`.

3. **Chunking and Iteration:**
   - **Row and Column Splits:** Divides the query (`q`), output (`o`), mask (`row_mask`), row sums (`row_sums`), and row maxes (`row_maxes`) into chunks (`q.split(...)`, `o.split(...)`, `mask`, etc.).
   - **Nested Loops:** Iterates over each chunk of queries (`qc`) and corresponding masks (`row_mask`), then within each chunk, iterates over chunks of keys (`kc`) and values (`vc`).

4. **Attention Calculation:**
   - **Matrix Multiplication and Scaling:** Computes attention weights (`attn_weights`) using Einstein summation (`einsum`) of queries (`qc`) and keys (`kc`), scaled by `scale`.

5. **Mask Application:**
   - **Column Masking:** Applies column-wise mask (`col_mask`) to attention weights (`attn_weights`), filling masked positions with `max_neg_value` to ignore those positions during softmax calculation.

6. **Causal Attention Handling:**
   - **Causal Masking:** If `causal` is `True`, applies a causal mask (`causal_mask`) to `attn_weights` to ensure the model only attends to previous positions.

7. **Normalization and Update:**
   - **Row-wise Operations:** Computes block-wise maximum (`block_row_maxes`) and updates (`new_row_maxes`), exponentiates attention weights (`exp_weights`), and sums them (`block_row_sums`). Computes exponential row differences (`exp_row_max_diff`), updates row sums (`new_row_sums`), and updates output (`oc`) accordingly.

8. **Normalization of Output:**
   - **Output Normalization:** Normalizes `oc` by dividing by `row_sums` to get the final attention outputs for each chunk.

9. **Log-Sum-Exp Calculation:**
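    - **Log-Sum-Exp (lse):** Combines `all_row_sums` and `all_row_maxes` into the log-sum-exp (`lse`) of each row of scores, i.e. the row maximum plus the log of the row sum of shifted exponentials. The backward pass uses `lse` to recompute attention probabilities.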

10. **Context Management:**
    - **Context Setup:** Stores necessary parameters and tensors (`causal`, `scale`, `mask`, `q_bucket_size`, `k_bucket_size`, `q`, `k`, `v`, `o`, `lse`) in the context (`ctx`) for use during the backward pass (`backward` method).

11. **Return:**
    - **Return Output:** Returns the computed output tensor `o` representing the attended values.
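
The running-max and running-sum update in steps 7 to 9 can be sketched for a single head. This is a minimal illustration, not the repository's code: it assumes unbatched 2-D tensors, no masks, and chunking over keys only. The function name `chunked_attention_sketch` is hypothetical; the local variable names mirror those in the breakdown.

```python
import torch

def chunked_attention_sketch(q, k, v, k_bucket_size=4):
    """q: (n_q, d), k: (n_k, d), v: (n_k, d_v). Softmax attention computed one key chunk at a time."""
    scale = q.shape[-1] ** -0.5
    o = torch.zeros(q.shape[0], v.shape[-1])
    row_sums = torch.zeros(q.shape[0], 1)
    row_maxes = torch.full((q.shape[0], 1), float('-inf'))
    for kc, vc in zip(k.split(k_bucket_size), v.split(k_bucket_size)):
        attn_weights = (q @ kc.T) * scale
        block_row_maxes = attn_weights.amax(dim=-1, keepdim=True)
        new_row_maxes = torch.maximum(row_maxes, block_row_maxes)
        exp_weights = torch.exp(attn_weights - new_row_maxes)
        block_row_sums = exp_weights.sum(dim=-1, keepdim=True)
        # Rescale what was accumulated so far to the new running maximum
        exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
        o = o * exp_row_max_diff + exp_weights @ vc
        row_sums = row_sums * exp_row_max_diff + block_row_sums
        row_maxes = new_row_maxes
    lse = row_sums.log() + row_maxes      # log-sum-exp of each row of scores
    return o / row_sums, lse

torch.manual_seed(0)
q, k, v = torch.randn(6, 8), torch.randn(10, 8), torch.randn(10, 8)
out, lse = chunked_attention_sketch(q, k, v, k_bucket_size=4)
reference = torch.softmax((q @ k.T) * 8 ** -0.5, dim=-1) @ v
print((out - reference).abs().max())
```

The chunked result should agree with ordinary softmax attention up to floating-point error, even though no full `(n_q, n_k)` score matrix is ever built inside the loop.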

### Summary

The `forward` method computes the FlashAttention forward pass in chunks (`q_bucket_size`, `k_bucket_size`), applying attention masks (`mask`), supporting causal attention (`causal`), and accumulating outputs (`o`) without materializing the full attention matrix. It uses PyTorch tensor operations (`einsum`, `masked_fill_`, `clamp`, `exp`, etc.) and the autograd context (`ctx`) to enable gradient computation in the `backward` pass. Because only one tile of attention scores exists at a time, peak memory for the scores depends on the bucket sizes rather than on the square of the sequence length, which is what makes long sequences tractable.


The `backward` method of the `FlashAttentionFunction` class implements the backward pass for the FlashAttention mechanism as part of a PyTorch custom autograd function.

### Purpose

The `backward` method computes the gradients of the loss with respect to the inputs (`q`, `k`, `v`), given the gradient of the loss with respect to the output (`do`). It uses the saved tensors (`q`, `k`, `v`, `o`, `lse`) and other context parameters (`causal`, `scale`, `mask`, `q_bucket_size`, `k_bucket_size`) from the forward pass, so the attention probabilities can be recomputed chunk by chunk instead of being stored.

### Breakdown of the `backward` Method

1. **Initialization and Setup:**
   - **Context Retrieval:** Retrieves saved tensors (`q`, `k`, `v`, `o`, `lse`) and context parameters (`causal`, `scale`, `mask`, `q_bucket_size`, `k_bucket_size`) from the context (`ctx`).

2. **Tensor Initialization:**
   - **Gradient Initialization:** Initializes tensors `dq`, `dk`, and `dv` to zeros with the same shape as `q`, `k`, and `v`, respectively, to accumulate gradients.

3. **Chunking and Iteration:**
   - **Row Splits:** Splits `q`, `o`, `do`, `mask`, `lse`, and `dq` into chunks (`q.split(...)`, `o.split(...)`, etc.) for memory efficiency during computation.

4. **Nested Loops for Gradient Computation:**
   - **Column Splits:** Iterates over chunks of `k`, `v`, `dk`, `dv`, and `row_mask`.
   - **Attention Calculation:** Computes attention weights (`attn_weights`) using `qc` (query chunk) and `kc` (key chunk) scaled by `scale`.

5. **Mask Application:**
   - **Causal Masking:** If `causal` is `True`, applies a causal mask (`causal_mask`) to `attn_weights` to mask out future positions.

6. **Gradient Calculation:**
   - **Probability and Gradient Calculation:** Computes probabilities (`p`) from `attn_weights` and `lse`, masks `p` with `col_mask`, and computes gradients (`dv_chunk`, `dp`, `ds`) based on attention weights, output gradients (`doc`), and values (`vc`).

7. **Backpropagation Update:**
   - **Gradient Accumulation:** Accumulates gradients (`dq_chunk`, `dk_chunk`, `dv_chunk`) computed for each chunk of `qc`, `kc`, and `vc` into `dqc`, `dkc`, and `dvc`, respectively.

8. **Return Gradients:**
   - **Return:** Returns the computed gradients `dq`, `dk`, and `dv`, followed by `None` for each remaining `forward` argument that does not require a gradient (such as `mask`, `causal`, and the bucket sizes).
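
The gradient formulas behind step 6 can be sketched without chunking. This is a minimal single-head illustration, not the repository's code: it assumes 2-D tensors and no masks, and the name `attention_backward_sketch` is hypothetical. It recomputes the probabilities `p` from `lse` and checks the result against PyTorch autograd.

```python
import torch

def attention_backward_sketch(q, k, v, o, lse, do):
    """Gradients of softmax attention, with p recomputed from lse instead of being stored."""
    scale = q.shape[-1] ** -0.5
    attn_weights = (q @ k.T) * scale
    p = torch.exp(attn_weights - lse)             # softmax probabilities
    dv = p.T @ do                                 # gradient w.r.t. values
    dp = do @ v.T                                 # gradient w.r.t. probabilities
    D = (do * o).sum(dim=-1, keepdim=True)        # per-row correction term of the softmax gradient
    ds = p * scale * (dp - D)                     # gradient w.r.t. the unscaled scores q @ k.T
    dq = ds @ k
    dk = ds.T @ q
    return dq, dk, dv

torch.manual_seed(0)
q, k, v = (torch.randn(5, 8, requires_grad=True) for _ in range(3))
s = (q @ k.T) * q.shape[-1] ** -0.5
o = torch.softmax(s, dim=-1) @ v
lse = torch.logsumexp(s, dim=-1, keepdim=True)
do = torch.randn_like(o)

dq, dk, dv = attention_backward_sketch(q.detach(), k.detach(), v.detach(), o.detach(), lse.detach(), do)
ref_dq, ref_dk, ref_dv = torch.autograd.grad(o, (q, k, v), grad_outputs=do)
print((dq - ref_dq).abs().max(), (dk - ref_dk).abs().max(), (dv - ref_dv).abs().max())
```

In the chunked version described above, the same formulas are applied per (`qc`, `kc`, `vc`) tile and the results are accumulated into `dq`, `dk`, and `dv`.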

### Summary

The `backward` method efficiently computes gradients for the FlashAttention mechanism using backpropagation. It iterates over chunks of tensors (`q`, `k`, `v`, `o`, `lse`) to manage memory usage, applies masks (`mask`, `causal_mask`) to attention weights (`attn_weights`), computes gradients (`dq`, `dk`, `dv`) based on output gradients (`do`), and accumulates them into corresponding tensors. This method leverages PyTorch's tensor operations (`einsum`, `masked_fill_`, etc.) and context management (`ctx`) for efficient gradient computation in deep learning models.



### `FlashAttention` Class (Module Wrapper)

This class encapsulates the FlashAttention mechanism as a PyTorch module, providing an easy-to-use interface for integration into neural network architectures.

- **Initialization (`__init__` method):**
  - Sets up parameters for the FlashAttention mechanism, including number of heads, dimensions, bucket sizes, etc.
  - Defines linear transformations (`to_q`, `to_kv`, `to_out`) for queries, keys, and values.
  - Optional configurations for parallel execution (`parallel`) and mixed precision training (`mixed_precision`).

- **Forward Pass (`forward` method):**
  - Handles the forward pass of the attention mechanism.
  - Splits input data into chunks based on `q_bucket_size`.
  - Optionally parallelizes computation across multiple GPUs.
  - Optionally applies mixed precision for faster computation.
  - Utilizes `FlashAttentionFunction.apply` to compute attention and rearranges outputs.

### Key Concepts

- **Chunking:** Divides input tensors (`q`, `k`, `v`) into smaller chunks (`q_bucket_size`, `k_bucket_size`) to manage memory and computation efficiently.
- **Causal Attention:** Supports causal attention by masking future positions during computation.
- **Memory Efficiency:** Uses incremental updates and chunk-wise computations to handle large-scale data efficiently.
- **Custom Autograd:** Implements custom forward and backward methods using `torch.autograd.Function` to define the FlashAttention mechanism.

### Usage

To use this implementation:

```python
import torch
# FlashAttention is assumed to be imported from the repository linked above

# Instantiate FlashAttention module
flash_attn = FlashAttention(dim=512, heads=8, dim_head=64, causal=True)

# Example usage
x = torch.randn(32, 100, 512)  # Batch size of 32, sequence length of 100, dimension of 512
output = flash_attn(x)
```

This setup allows you to integrate the FlashAttention mechanism into your neural network architectures, particularly useful for handling large-scale datasets or scenarios requiring efficient attention mechanisms. Adjust parameters (`q_bucket_size`, `k_bucket_size`, etc.) as needed based on your specific application requirements.

# Replacing existing Attention Layer with Custom Attention Layer

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import OPTForCausalLM, OPTConfig

class SparseAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size):
        super(SparseAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == embed_dim, "Embedding dimension must be divisible by num_heads"
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.out_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_length, embed_dim = x.size()
        assert embed_dim == self.embed_dim, "Input embedding dimension must match the module's embedding dimension"

        # Linear projections
        q = self.q_linear(x)  # (batch_size, seq_length, embed_dim)
        k = self.k_linear(x)  # (batch_size, seq_length, embed_dim)
        v = self.v_linear(x)  # (batch_size, seq_length, embed_dim)

        # Reshape for multi-head attention
        q = q.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_length, head_dim)
        k = k.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_length, head_dim)
        v = v.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_length, head_dim)

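        # Compute scaled attention scores (dividing by sqrt(head_dim) keeps softmax inputs in a stable range)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (batch_size, num_heads, seq_length, seq_length)

        # Apply local (sparse) attention mask, moved to the same device as the input
        mask = self._create_local_attention_mask(seq_length).to(x.device)
        attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))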
        attn_probs = F.softmax(attn_scores, dim=-1)  # (batch_size, num_heads, seq_length, seq_length)

        # Compute attention output
        attn_output = torch.matmul(attn_probs, v)  # (batch_size, num_heads, seq_length, head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_dim)  # (batch_size, seq_length, embed_dim)

        # Final linear projection
        output = self.out_linear(attn_output)  # (batch_size, seq_length, embed_dim)
        return output

    def _create_local_attention_mask(self, seq_length):
        mask = torch.zeros(seq_length, seq_length)
        for i in range(seq_length):
            start = max(0, i - self.window_size // 2)
            end = min(seq_length, i + self.window_size // 2 + 1)
            mask[start:end, i] = 1
        return mask.unsqueeze(0).unsqueeze(0).to(torch.bool)  # (1, 1, seq_length, seq_length)


In [None]:
class CustomOPTModel(OPTForCausalLM):
    def __init__(self, config):
        super(CustomOPTModel, self).__init__(config)
        self._replace_attention_layers()

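    def _replace_attention_layers(self):
        # Note: Hugging Face OPT implements attention with its own module class rather than
        # nn.MultiheadAttention, so this check matches nothing on a stock OPT model. Adapt the
        # isinstance target, and give the replacement a compatible forward signature, before relying on it.
        targets = [(name, module) for name, module in self.named_modules()
                   if isinstance(module, nn.MultiheadAttention)]
        for layer_name, layer_module in targets:
            # named_modules() yields dotted paths, so set the attribute on the direct parent
            parent_name, _, child_name = layer_name.rpartition(".")
            parent = self.get_submodule(parent_name) if parent_name else self
            # Replace the self-attention layer with SparseAttention
            setattr(parent, child_name, SparseAttention(
                embed_dim=layer_module.embed_dim,
                num_heads=layer_module.num_heads,
                window_size=3  # Example window size
            ))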


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

# Load tokenizer appropriate for your model
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')

# Example input
inputs = tokenizer("Hello, my dog is cute", return_tensors='pt')

# Load pre-trained model configuration and model
config = AutoConfig.from_pretrained('facebook/opt-350m')
model = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')

# Example usage
outputs = model(**inputs)
print(inputs.items())
print(inputs.tokens())
print(outputs.logits.shape)


dict_items([('input_ids', tensor([[    2, 31414,     6,   127,  2335,    16, 11962]])), ('attention_mask', tensor([[1, 1, 1, 1, 1, 1, 1]]))])
['</s>', 'Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute']
torch.Size([1, 7, 50272])



**Model Structure**: The code assumes the model has a modular structure, typically found in transformer models.

**Layer Access**: Accessing and replacing layers depends on the model’s architecture. For instance, in BERT, you access encoder layers through model.encoder.layer, while in GPT or OPT, you might need to adjust according to the specific layer structure.

**Modification Scope**: This method can be adapted to replace other components like feed-forward layers, normalization layers, etc.

**Model-Specific Adjustments**: The exact implementation may vary based on the model. For instance, GPT models use transformer.h instead of encoder.layer.
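
One detail that applies to any architecture: named_modules() returns dotted paths such as layers.0.attn, so a replacement has to be set on the direct parent module rather than on the top-level model. The sketch below shows this on a toy model so that it runs without downloading weights. ToyBlock, ToyModel, and replace_modules are hypothetical names, and nn.Identity stands in for a real custom layer.

```python
import torch.nn as nn

def replace_modules(root, target_type, factory):
    """Replace every submodule of target_type with factory(old_module); return the replaced names."""
    # Collect names first so the module tree is not mutated while it is being iterated
    targets = [name for name, m in root.named_modules() if name and isinstance(m, target_type)]
    for name in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = root.get_submodule(parent_name) if parent_name else root
        setattr(parent, child_name, factory(getattr(parent, child_name)))
    return targets

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=8, num_heads=2)
        self.ff = nn.Linear(8, 8)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock(), ToyBlock()])

model = ToyModel()
replaced = replace_modules(model, nn.MultiheadAttention, lambda old: nn.Identity())
print(replaced)
print(type(model.layers[0].attn).__name__)
```

For a real pretrained model, the target type would be that model's own attention class, and the replacement would need to accept the same forward arguments and return the same output structure.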

# General pattern for Layer replacement

In [None]:
# import torch
# import torch.nn as nn
# from transformers import BertModel

# class CustomAttention(nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super(CustomAttention, self).__init__()
#         self.multihead_attention = nn.MultiheadAttention(embed_dim, num_heads)

#     def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None,
#                 encoder_attention_mask=None, past_key_value=None, output_attentions=False):
#         # Use the multihead_attention module to perform the attention operation
#         return self.multihead_attention(hidden_states, hidden_states, hidden_states,
#                                         attn_mask=attention_mask, key_padding_mask=attention_mask)[0]


## Custom LayerNorm + Residual connection + Attention

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel

class CustomLayerNorm(nn.LayerNorm):
    def __init__(self, normalized_shape, eps=1e-12):
        super(CustomLayerNorm, self).__init__(normalized_shape, eps=eps)

    def forward(self, x):
        # Custom behavior, if needed, can be added here
        return super(CustomLayerNorm, self).forward(x)


class CustomResidualConnection(nn.Module):
    def __init__(self, dropout_prob=0.1):
        super(CustomResidualConnection, self).__init__()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x, residual):
        # Dropout is applied to the sublayer output x; the skip path (residual) is left untouched
        return residual + self.dropout(x)


class CustomAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(CustomAttention, self).__init__()
        # batch_first=True so inputs are read as (batch, seq_len, embed_dim)
        self.multihead_attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm = CustomLayerNorm(normalized_shape=embed_dim)
        self.residual_connection = CustomResidualConnection()

    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None,
                encoder_attention_mask=None, past_key_value=None, output_attentions=False):
        # Assumes attention_mask has shape (batch, seq_len) with 1 = real token and 0 = padding.
        # nn.MultiheadAttention uses the opposite convention: True marks positions to ignore.
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = attention_mask == 0

        # Perform attention operation
        attention_output, _ = self.multihead_attention(
            hidden_states, hidden_states, hidden_states,
            key_padding_mask=key_padding_mask
        )

        # Apply layer normalization and residual connection
        normalized_output = self.layer_norm(attention_output)
        residual_output = self.residual_connection(normalized_output, hidden_states)

        # Return as a tuple
        return (residual_output,)


## Integration of Custom Layers

In [None]:
from transformers import AutoModel, AutoConfig

class CustomModel(nn.Module):
    def __init__(self, model_name):
        super(CustomModel, self).__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self._replace_attention_layers()

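    def _replace_attention_layers(self):
        # Note: BERT in transformers uses its own attention classes, not nn.MultiheadAttention,
        # so adapt the isinstance target to the wrapped model before expecting replacements.
        targets = [(name, module) for name, module in self.named_modules()
                   if isinstance(module, nn.MultiheadAttention)]
        for layer_name, layer_module in targets:
            # Dotted names must be resolved to the direct parent before calling setattr
            parent_name, _, child_name = layer_name.rpartition(".")
            parent = self.get_submodule(parent_name) if parent_name else self
            embed_dim = layer_module.embed_dim
            num_heads = layer_module.num_heads
            setattr(parent, child_name, CustomAttention(embed_dim, num_heads))
            print(f"Replaced attention layer {layer_name}")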

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
                head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None,
                past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None,
                return_dict=None):
        # Ensure the forward method matches the input signature of BERT
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        return outputs

In [None]:
# Example usage
model_name = 'bert-base-uncased'
custom_model = CustomModel(model_name)

# Example input (adjust according to the model's tokenizer)
input_ids = torch.tensor([[101, 1045, 2064, 1005, 1055, 1037, 1000, 102]], dtype=torch.long)
attention_mask = torch.ones_like(input_ids)

outputs = custom_model(input_ids=input_ids, attention_mask=attention_mask)
print(outputs.keys())
print(outputs.last_hidden_state.shape)  # Example output shape


odict_keys(['last_hidden_state', 'pooler_output'])
torch.Size([1, 8, 768])


# Local Attention & Global Attention

https://github.com/lucidrains/local-attention




Local attention and global attention are two strategies used in attention mechanisms, especially in the context of sequence modeling tasks such as machine translation, text generation, and image captioning. Here’s an explanation of each:

### Local Attention

**Definition:**
Local attention focuses only on a subset of the entire input sequence or image at each step of the attention mechanism. Instead of attending to all positions or regions globally, it restricts the attention to a fixed-size window or neighborhood around the current position. This window can be centered around the position of interest and is often defined by a predefined window size or radius.

**Characteristics:**
- **Computationally Efficient:** Since local attention only considers a limited number of positions or regions, it reduces the computational cost compared to global attention, especially for long sequences or large images.
  
- **Contextual Relevance:** By focusing on a local neighborhood, local attention can potentially capture more relevant information that is closer in proximity to the current position or region.

- **Fixed Context Window:** The size of the window or neighborhood is fixed and does not change dynamically based on the input sequence or task. This fixed context may limit the model’s ability to capture long-range dependencies or context that extends beyond the defined window.

**Applications:**
Local attention is commonly used in tasks where the input sequences are long and maintaining computational efficiency is crucial. For example:
- Language modeling and text generation where the model needs to focus on nearby words for coherence and fluency.
- Image processing tasks where the model needs to attend to neighboring pixels rather than the entire image for feature extraction or segmentation.

### Global Attention

**Definition:**
Global attention, also known as full attention or unrestricted attention, allows the model to attend to all positions or regions across the entire input sequence or image simultaneously. Unlike local attention, there are no restrictions or fixed-size windows dictating where the model can attend.

**Characteristics:**
- **Comprehensive Context:** Global attention considers all positions or regions, providing the model with a comprehensive view of the entire input sequence or image. This helps capture long-range dependencies and context that spans the entire input.

- **Higher Computational Cost:** Because global attention attends to all positions or regions, it is computationally more expensive compared to local attention, especially for large inputs.

- **Dynamic Relevance:** The relevance of each position or region is dynamically determined based on the attention weights computed during the attention mechanism. This dynamic relevance helps the model adaptively focus on important parts of the input.

**Applications:**
Global attention is suitable for tasks where capturing long-range dependencies and maintaining a comprehensive understanding of the input is critical. For example:
- Machine translation where the model needs to align words or tokens from the source and target languages across the entire sentence.
- Image classification or object detection tasks where the model needs to consider all parts of the image for recognizing objects or patterns.

### Comparison

- **Scope of Attention:** Local attention focuses on a limited neighborhood or window around the current position, whereas global attention attends to all positions or regions across the entire input.
  
- **Computational Efficiency:** Local attention is more computationally efficient due to its restricted scope, whereas global attention is more resource-intensive but captures more comprehensive context.
  
- **Suitability:** The choice between local and global attention depends on the task requirements, input size, and computational constraints. Local attention is favored for efficiency and tasks with local dependencies, while global attention is preferred for tasks requiring a broader context and long-range dependencies.

In practice, some models may use hybrid approaches that combine aspects of both local and global attention to balance efficiency and context capture, such as incorporating global attention for overall context and local attention for fine-grained details.


1. **LocalAttention Class**: This class defines a custom local attention layer. It computes attention scores using query, key, and value projections similar to standard attention mechanisms but applies a local window mask to restrict attention to a fixed window size.

2. **CustomTransformer Class**: This class integrates both `LocalAttention` and `MultiheadAttention` (for global attention). It combines the outputs of these attention mechanisms and passes them through a feedforward neural network for final processing.

3. **Example Usage**: In the `__main__` block:
   - We instantiate `CustomTransformer` with specified dimensions (`embed_dim`, `num_heads`) and `window_size`.
   - Generate a random dummy input tensor (`input_tensor`).
   - Perform a forward pass through the model (`model`) and print the shapes of the input and output tensors.

### Notes:
- **LocalAttention**: The `LocalAttention` layer in this example applies a simple left-aligned mask to restrict attention within a local window size (`window_size`). This can be further customized based on specific requirements or applications.
  
- **CustomTransformer**: Integrates both local and global attention mechanisms in a simple feedforward neural network architecture. In practice, depending on the task and requirements, you might adjust the specifics of each attention layer, such as window size for local attention or the number of heads for global attention.

This example provides a basic framework for understanding how local and global attention mechanisms can be integrated within a transformer-based model in PyTorch. Adjustments and enhancements can be made based on specific use cases or tasks.
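
As one possible customization of the mask, a sliding window centered on each query position can be built without loops. This is a minimal sketch rather than part of the example model: `sliding_window_mask` is a hypothetical helper, and it assumes a symmetric window that reaches `window_size // 2` positions to each side.

```python
import torch
import torch.nn.functional as F

def sliding_window_mask(seq_len, window_size):
    """Additive mask: 0 inside the window around each query position, -inf outside."""
    idx = torch.arange(seq_len)
    distance = (idx[:, None] - idx[None, :]).abs()     # |query position - key position|
    allowed = distance <= window_size // 2
    mask = torch.zeros(seq_len, seq_len)
    mask[~allowed] = float('-inf')
    return mask

torch.manual_seed(0)
scores = torch.randn(2, 5, 5)                           # (batch, seq_len, seq_len)
weights = F.softmax(scores + sliding_window_mask(5, 3), dim=-1)
print(weights[0])
```

Every row keeps at least its diagonal entry, so softmax never receives a row of only `-inf` values, and the `(seq_len, seq_len)` mask broadcasts across the batch dimension.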

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Custom Local Attention Layer
class LocalAttention(nn.Module):
    def __init__(self, embed_dim, window_size=5):
        super(LocalAttention, self).__init__()
        self.embed_dim = embed_dim
        self.window_size = window_size
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        attention_scores = torch.matmul(q, k.transpose(-1, -2)) / (self.embed_dim ** 0.5)
        print(attention_scores.shape)
        # Apply the left-aligned window mask: keys at positions >= window_size are blocked,
        # so every query attends only to the first window_size positions
        mask = torch.zeros_like(attention_scores)
        mask[:, :, self.window_size:] = -float('inf')
        print(mask.shape)
        attention_scores = attention_scores + mask

        attention_weights = F.softmax(attention_scores, dim=-1)
        context = torch.matmul(attention_weights, v)

        return context

# Custom Transformer Model using both Local and Global Attention
class CustomTransformer(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size=5):
        super(CustomTransformer, self).__init__()
        self.local_attention = LocalAttention(embed_dim, window_size)
        # batch_first=True because inputs are (batch, seq_len, embed_dim)
        self.global_attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.feedforward = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.ReLU(),
            nn.Linear(4 * embed_dim, embed_dim)
        )

    def forward(self, x):
        # Local Attention
        local_context = self.local_attention(x)

        # Global Attention
        global_context, _ = self.global_attention(x, x, x)

        # Combine local and global contexts
        combined_context = local_context + global_context

        # Feedforward layer
        output = self.feedforward(combined_context)

        return output

# Example usage
if __name__ == "__main__":
    embed_dim = 16
    num_heads = 4
    window_size = 5
    batch_size = 8
    seq_len = 10

    # Create an instance of the custom transformer model
    model = CustomTransformer(embed_dim, num_heads, window_size)

    # Generate a dummy input tensor
    input_tensor = torch.randn(batch_size, seq_len, embed_dim)

    # Forward pass through the model
    output_tensor = model(input_tensor)

    print("Input Tensor Shape:", input_tensor.shape)
    print("Output Tensor Shape:", output_tensor.shape)


torch.Size([8, 10, 10])
torch.Size([8, 10, 10])
Input Tensor Shape: torch.Size([8, 10, 16])
Output Tensor Shape: torch.Size([8, 10, 16])
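
The effect of a local window mask on the softmax can be checked in isolation. The snippet below is a minimal sketch, not part of the model above: it assumes a symmetric window of `window_size // 2` positions on each side of the query, and `scores` is random stand-in data rather than real attention scores.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, window_size = 6, 3
scores = torch.randn(1, seq_len, seq_len)  # stand-in attention scores

# Band mask: True where the key is within window_size // 2 of the query
idx = torch.arange(seq_len)
inside = (idx[:, None] - idx[None, :]).abs() <= window_size // 2

# Positions outside the window get -inf, so softmax assigns them zero weight
masked_scores = scores.masked_fill(~inside, float('-inf'))
weights = F.softmax(masked_scores, dim=-1)

print(weights[0, 0])        # query 0 only attends to keys 0 and 1
print(weights.sum(dim=-1))  # every row still sums to 1
```

Because the diagonal is always inside the window, no row is fully masked, which avoids NaN values from a softmax over all `-inf` entries.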



In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

# class SparseAttention(nn.Module):
#     def __init__(self, input_dim, embed_dim, num_heads, window):
#         super(SparseAttention, self).__init__()
#         self.head_dim = embed_dim // num_heads
#         self.num_heads = num_heads
#         self.window = window
#         self.q_proj = nn.Linear(input_dim, embed_dim)
#         self.k_proj = nn.Linear(input_dim, embed_dim)
#         self.v_proj = nn.Linear(input_dim, embed_dim)
#         self.output_proj = nn.Linear(embed_dim, input_dim)

#     def _create_local_attention_mask(self, seq):
#         mask = torch.zeros(seq, seq)  # [seq, seq]
#         for i in range(seq):
#             start = max(0, i - self.window // 2)
#             end = min(seq, i + self.window // 2 + 1)
#             mask[start:end, i] = 1
#         mask = mask.unsqueeze(0).unsqueeze(0).to(torch.bool)  # [1, 1, seq, seq]
#         return mask

#     def forward(self, x, local_attn=False):
#         batch, seq, embed_dim = x.size()
#         q = self.q_proj(x)
#         k = self.k_proj(x)
#         v = self.v_proj(x)

#         q = q.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # [batch,n_heads,seq,head_dim]
#         k = k.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
#         v = v.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

#         attn_scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5) # [batch,n_heads,q_seq,k_seq]

#         if local_attn:
#             mask = self._create_local_attention_mask(seq)
#             attn_scores = attn_scores.masked_fill(~mask, float('-inf'))

#         attn_weights = torch.softmax(attn_scores, dim=-1)
#         attn_output = torch.matmul(attn_weights, v) # [batch,n_heads,q_seq,head_dim]

#         attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(batch, seq, -1) # [batch,q_seq,input_dim]
#         attn_output = self.output_proj(attn_output)

#         return attn_output, attn_weights

class SparseAttention(nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads, window):
        super(SparseAttention, self).__init__()
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads
        self.window = window
        self.q_proj = nn.Linear(input_dim, embed_dim)
        self.k_proj = nn.Linear(input_dim, embed_dim)
        self.v_proj = nn.Linear(input_dim, embed_dim)
        self.output_proj = nn.Linear(embed_dim, input_dim)

    def _create_local_attention_mask(self, seq):
        mask = torch.zeros(seq, seq, dtype=torch.bool)  # [seq, seq]
        for i in range(seq):
            start = max(0, i - self.window // 2)
            end = min(seq, i + self.window // 2 + 1)
            mask[start:end, i] = 1
        mask = mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq, seq]
        return mask

    def forward(self, x, local_attn=False):
        batch_size, seq_len, embed_dim = x.size()  # x: [batch, seq_len, input_dim]
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape and transpose for multi-head attention
        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
        k = k.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
        v = v.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [batch, num_heads, seq_len, seq_len]

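        if local_attn:
            # The mask is True inside the window, so fill everything outside it with -inf
            mask = self._create_local_attention_mask(seq_len).to(attn_scores.device)
            attn_scores = attn_scores.masked_fill(~mask, float('-inf'))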

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)  # [batch, num_heads, seq_len, head_dim]

        # Merge heads back; the width is num_heads * head_dim, which may differ from input_dim
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.num_heads * self.head_dim)
        attn_output = self.output_proj(attn_output)

        return attn_output, attn_weights

class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.model = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')
        self._replace_attention_layer() # can add more custom layers like this

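    def _replace_attention_layer(self):
        # Note: OPT's decoder layers use the transformers library's own attention
        # class (OPTAttention) rather than nn.MultiheadAttention, so this check matches
        # nothing for 'facebook/opt-350m' and the wrapped model is left unchanged.
        targets = [(name, module) for name, module in self.model.named_modules()
                   if isinstance(module, nn.MultiheadAttention)]
        for name, module in targets:
            # Create the replacement from the existing module's parameters
            new_attention_layer = SparseAttention(
                input_dim=module.embed_dim,
                embed_dim=module.embed_dim,
                num_heads=module.num_heads,
                window=4,  # Example window size
            )
            # named_modules() yields dotted paths, so set the attribute on the parent module
            parent_name, _, child_name = name.rpartition('.')
            parent = self.model.get_submodule(parent_name) if parent_name else self.model
            setattr(parent, child_name, new_attention_layer)

    def forward(self, *args, local_attn=False, **kwargs):
        # local_attn is accepted but not forwarded to the wrapped model
        return self.model(*args, **kwargs)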


# Tokenization and Model Execution
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
inputs = tokenizer("If you modify the model",return_tensors='pt')
inputs.items()

dict_items([('input_ids', tensor([[    2,  1106,    47, 23209,     5,  1421]])), ('attention_mask', tensor([[1, 1, 1, 1, 1, 1]]))])
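
Replacing a layer found through `named_modules()` needs some care, because the names it yields are dotted paths such as `layers.0.attn`, and `setattr` on the top-level model does not reach nested children. The following is a minimal sketch on a toy model, not on OPT: `TinyBlock`, `TinyModel`, and `swap_modules` are hypothetical names introduced for illustration, and `nn.Identity` only stands in for a custom attention layer.

```python
import torch.nn as nn

class TinyBlock(nn.Module):  # hypothetical toy block
    def __init__(self, embed_dim=16, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ff = nn.Linear(embed_dim, embed_dim)

class TinyModel(nn.Module):  # hypothetical toy model with nested blocks
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([TinyBlock(), TinyBlock()])

def swap_modules(model, target_type, make_replacement):
    """Replace every submodule of target_type; return the dotted names replaced."""
    # Collect first so the module tree is not modified while iterating over it
    targets = [(name, m) for name, m in model.named_modules()
               if isinstance(m, target_type) and name]
    for name, old in targets:
        # Split "layers.0.attn" into parent path "layers.0" and attribute "attn"
        parent_name, _, child_name = name.rpartition('.')
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, make_replacement(old))
    return [name for name, _ in targets]

toy = TinyModel()
# nn.Identity stands in for a custom attention layer here
replaced = swap_modules(toy, nn.MultiheadAttention, lambda old: nn.Identity())
remaining = [n for n, m in toy.named_modules() if isinstance(m, nn.MultiheadAttention)]
print(replaced)
print(remaining)
```

Returning the list of replaced names makes it easy to confirm that a swap actually happened; an empty list means the type check matched nothing in the model.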

In [None]:
inputs["input_ids"][0].shape[0]

6

In [None]:
def create_local_attention_mask(seq, window=3):
    # True where a position lies within window // 2 of position i
    mask = torch.zeros(seq, seq, dtype=torch.bool)  # [seq, seq]
    for i in range(seq):
        start = max(0, i - window // 2)
        end = min(seq, i + window // 2 + 1)
        mask[start:end, i] = True
    mask = mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq, seq]
    return mask

window = 3
seq = inputs["input_ids"][0].shape[0]
mask = create_local_attention_mask(seq, window)
print(seq)
print(mask)

6
tensor([[[[ True,  True, False, False, False, False],
          [ True,  True,  True, False, False, False],
          [False,  True,  True,  True, False, False],
          [False, False,  True,  True,  True, False],
          [False, False, False,  True,  True,  True],
          [False, False, False, False,  True,  True]]]])
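
The same band mask can also be built without a Python loop. This is a minimal sketch, assuming the symmetric window used above; `local_mask_loop` and `local_mask_vectorized` are illustrative names, and the loop version is repeated only so the two results can be compared.

```python
import torch

def local_mask_loop(seq, window=3):
    # Loop-based construction, mirroring the cell above
    mask = torch.zeros(seq, seq, dtype=torch.bool)
    for i in range(seq):
        start = max(0, i - window // 2)
        end = min(seq, i + window // 2 + 1)
        mask[start:end, i] = True
    return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq, seq]

def local_mask_vectorized(seq, window=3):
    # |i - j| <= window // 2 describes the same band around the diagonal
    idx = torch.arange(seq)
    band = (idx[:, None] - idx[None, :]).abs() <= window // 2
    return band.unsqueeze(0).unsqueeze(0)  # [1, 1, seq, seq]

loop_mask = local_mask_loop(6, window=3)
vec_mask = local_mask_vectorized(6, window=3)
print(torch.equal(loop_mask, vec_mask))
```

The vectorized form avoids the per-position loop, which matters more as the sequence length grows.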


### Next Word Prediction

In [None]:
model = CustomModel()
outputs = model(**inputs, local_attn=True)


logits = outputs.logits  # Raw logits (scores for each token in the vocabulary)
hidden_states = outputs.hidden_states  # Optional, if return_dict=True and output_hidden_states=True
attention_weights = outputs.attentions  # Optional, if return_dict=True and output_attentions=True
print("Logits shape:", logits.shape) # [batch_size, seq_length, vocab_size]
predicted_token_ids = torch.argmax(logits, dim=-1)  # [batch, seq]
decoded_output = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)

print("Decoded output:", decoded_output)

Logits shape: torch.Size([1, 6, 50272])
Decoded output: ["\n you're the game,"]


In [None]:
base_model = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')

# Baseline: the unmodified model, for comparison with CustomModel
base_outputs = base_model(**inputs)
logits_base_model = base_outputs.logits
pred_token_id_base_model = torch.argmax(logits_base_model ,dim=-1)
decoded_output_base_model = tokenizer.batch_decode(pred_token_id_base_model, skip_special_tokens=True)
print("Logits shape(Base model):", logits_base_model .shape)
print("Decoded output(Base model):", decoded_output_base_model)

Logits shape(Base model): torch.Size([1, 6, 50272])
Decoded output(Base model): ["\n you're the game,"]

