## **Input Embedding**

```tokens``` : the words or sub-words that a sentence is split into by the tokenizer.

```input_ids``` : each token has an integer index in the ```vocabulary```; the model identifies tokens by these ids.

```embed_tokens``` : the embedding layer of the model. Each id in ```input_ids``` is converted into an n-dimensional dense vector (of size ```embed_dim``` or ```hidden_sz```). These learned vectors capture semantic relationships between tokens.

During training or inference, multiple sequences (sentences) are grouped into a ```batch```. Each token in a sequence is converted to its corresponding embedding vector, and the model processes these batches in parallel for efficiency.
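
As a rough illustration of the terms above, the sketch below maps two toy sentences to ids and then to dense vectors. The vocabulary, the whitespace split and the sizes are made-up assumptions for illustration only; a real model uses its own sub-word tokenizer and a trained embedding matrix.

```python
import torch
import torch.nn as nn

# Hypothetical toy vocabulary; real models ship a tokenizer with tens of thousands of entries
vocabulary = {"<pad>": 0, "hello": 1, "my": 2, "dog": 3, "is": 4, "cute": 5}
embed_dim = 8  # illustrative size of each embedding vector

# tokens: a naive whitespace split stands in for a real sub-word tokenizer
sentences = ["hello my dog", "my dog is cute"]
tokens = [s.split() for s in sentences]

# input_ids: look up each token's index, padding shorter sequences to a common length
max_len = max(len(t) for t in tokens)
input_ids = torch.tensor(
    [[vocabulary[tok] for tok in t] + [vocabulary["<pad>"]] * (max_len - len(t)) for t in tokens]
)

# embed_tokens: one trainable row of size embed_dim per vocabulary entry
embed_tokens = nn.Embedding(len(vocabulary), embed_dim, padding_idx=0)
embeddings = embed_tokens(input_ids)

print(input_ids.shape)   # (batch, seq_len)
print(embeddings.shape)  # (batch, seq_len, embed_dim)
```

The batch dimension comes first, so both sentences are embedded in a single call.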

## **Positional Encoding**

These encodings are vectors that record the position of each token in the sequence and help the model understand word order and context. They are needed because a transformer processes all tokens in parallel and lacks the recurrence of RNNs, which consume the input one token at a time. Models such as Mistral currently use ```RotaryEmbedding``` (rotary position embeddings, a type of dynamic positional encoding applied to the query and key vectors) inside the ```self_attn``` block.
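
A minimal sketch of the rotary idea, in plain Python: each consecutive pair of vector components is rotated by an angle proportional to the token position. The pairing convention, the base of 10000 and the helper names ```rotate``` and ```dot``` are assumptions chosen for illustration, not a copy of any particular model's implementation. The point it demonstrates is that the query-key dot product then depends only on the relative offset between the two positions.

```python
import math

def rotate(vec, position, base=10000.0):
    """Rotate consecutive pairs of `vec` by position-dependent angles (rotary-style)."""
    dim = len(vec)
    out = []
    for i in range(0, dim, 2):
        theta = position * base ** (-i / dim)  # earlier pairs rotate faster
        x, y = vec[i], vec[i + 1]
        out.append(x * math.cos(theta) - y * math.sin(theta))
        out.append(x * math.sin(theta) + y * math.cos(theta))
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.5, -0.75, 1.0]

# Same relative offset (2) at different absolute positions
score_a = dot(rotate(q, 3), rotate(k, 1))
score_b = dot(rotate(q, 7), rotate(k, 5))
print(abs(score_a - score_b) < 1e-9)
```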

## **Self-Attention Mechanism**

For each token, self-attention computes a weighted sum over the representations of all tokens in the sequence. The weights come from “attention” scores computed between pairs of tokens using learned projections; tokens that are more relevant to one another receive higher “attention” weights.

**Query Vector**:

  - It represents the token for which the attention weights are being calculated.
  - The Query vector determines which parts of the input sequence should receive more attention.
  - Multiplying a token's embedding by the learned query weight matrix produces its Query vector, which is like asking, **"What should I pay attention to?"**

**Key Vector**:

  - It represents each token in the input sequence that the Query is compared against.
  - The Key vector helps identify the relevant or essential information in the input sequence.
  - Multiplying a token's embedding by the learned key weight matrix produces its Key vector, which is like asking, **"What is important to consider?"**

**Value Vector**:

  - It contains the information or features associated with each token in the input sequence.
  - The Value vector provides the actual data that is weighted and combined according to the attention weights calculated between the Query and Key.
  - The Value vector answers the question, **"What information do we have?"**

**Attention Weights/Scores** : Calculated for each pair of tokens and each head, as the softmax of the scaled Query-Key dot products. Each element represents the attention a particular head pays to a specific token when processing the input.

**Attention Output** : The weighted sum of the Value vectors, using the attention weights. It incorporates information from relevant tokens based on the attention scores, enriching each token's representation.
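
The Query, Key, Value, weights and output described above can be sketched for a single head as follows. The sizes and the random projection matrices are illustrative assumptions; in a trained model the projections are learned.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, embed_dim = 4, 8
x = torch.randn(seq_len, embed_dim)  # one sequence of token embeddings

# Projection matrices (randomly initialised here, learned in practice)
w_q = torch.randn(embed_dim, embed_dim)
w_k = torch.randn(embed_dim, embed_dim)
w_v = torch.randn(embed_dim, embed_dim)

q, k, v = x @ w_q, x @ w_k, x @ w_v

# Attention scores: similarity of every query with every key, scaled by sqrt(dim)
scores = q @ k.T / embed_dim ** 0.5   # (seq_len, seq_len)
weights = F.softmax(scores, dim=-1)   # each row sums to 1
output = weights @ v                  # (seq_len, embed_dim)

print(weights.sum(dim=-1))
print(output.shape)
```

Row ```i``` of ```weights``` says how much token ```i``` draws from every other token when its output vector is built.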

# SparseAttention
https://github.com/kyegomez/SparseAttention


# FlashAttention

https://github.com/kyegomez/FlashAttention20




# Replacing existing Attention Layer with Custom Attention Layer

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import OPTForCausalLM, OPTConfig

class SparseAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size):
        super(SparseAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == embed_dim, "Embedding dimension must be divisible by num_heads"
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.out_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_length, embed_dim = x.size()
        assert embed_dim == self.embed_dim, "Input embedding dimension must match the module's embedding dimension"

        # Linear projections
        q = self.q_linear(x)  # (batch_size, seq_length, embed_dim)
        k = self.k_linear(x)  # (batch_size, seq_length, embed_dim)
        v = self.v_linear(x)  # (batch_size, seq_length, embed_dim)

        # Reshape for multi-head attention
        q = q.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_length, head_dim)
        k = k.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_length, head_dim)
        v = v.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_length, head_dim)

        # Compute scaled attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (batch_size, num_heads, seq_length, seq_length)

        # Apply local (sparse) attention mask: positions outside the window get -inf
        mask = self._create_local_attention_mask(seq_length).to(x.device)
        attn_scores = attn_scores.masked_fill(~mask, float('-inf'))
        attn_probs = F.softmax(attn_scores, dim=-1)  # (batch_size, num_heads, seq_length, seq_length)

        # Compute attention output
        attn_output = torch.matmul(attn_probs, v)  # (batch_size, num_heads, seq_length, head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_dim)  # (batch_size, seq_length, embed_dim)

        # Final linear projection
        output = self.out_linear(attn_output)  # (batch_size, seq_length, embed_dim)
        return output

    def _create_local_attention_mask(self, seq_length):
        mask = torch.zeros(seq_length, seq_length)
        for i in range(seq_length):
            start = max(0, i - self.window_size // 2)
            end = min(seq_length, i + self.window_size // 2 + 1)
            mask[start:end, i] = 1
        return mask.unsqueeze(0).unsqueeze(0).to(torch.bool)  # (1, 1, seq_length, seq_length)


In [None]:
class CustomOPTModel(OPTForCausalLM):
    def __init__(self, config):
        super(CustomOPTModel, self).__init__(config)
        self._replace_attention_layers()

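    def _replace_attention_layers(self):
        # NOTE: Hugging Face OPT implements attention in its own OPTAttention class,
        # not nn.MultiheadAttention, so this check matches nothing for OPT as written.
        for layer_name, layer_module in list(self.named_modules()):
            if isinstance(layer_module, nn.MultiheadAttention):
                # layer_name is a dotted path, so set the attribute on the parent module
                parent_name, _, child_name = layer_name.rpartition('.')
                parent = self.get_submodule(parent_name) if parent_name else self
                # Replace the self-attention layer with SparseAttention
                setattr(parent, child_name, SparseAttention(
                    embed_dim=layer_module.embed_dim,
                    num_heads=layer_module.num_heads,
                    window_size=3  # Example window size
                ))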


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

# Load tokenizer appropriate for your model
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')

# Example input
inputs = tokenizer("Hello, my dog is cute", return_tensors='pt')

# Load pre-trained model configuration and model
config = AutoConfig.from_pretrained('facebook/opt-350m')
model = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')

# Example usage
outputs = model(**inputs)
print(inputs.items())
print(inputs.tokens())
print(outputs.logits.shape)


dict_items([('input_ids', tensor([[    2, 31414,     6,   127,  2335,    16, 11962]])), ('attention_mask', tensor([[1, 1, 1, 1, 1, 1, 1]]))])
['</s>', 'Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute']
torch.Size([1, 7, 50272])



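**Model Structure**: The code assumes the model has a modular structure, as transformer models typically do, and that its attention layers are instances of ```nn.MultiheadAttention```. Hugging Face models such as OPT and BERT implement attention in their own classes (for example ```BertSelfAttention```), so the ```isinstance``` check must target the class the model actually uses; otherwise nothing is replaced.

**Layer Access**: Accessing and replacing layers depends on the model’s architecture. For instance, in BERT you access encoder layers through ```model.encoder.layer```, while for GPT or OPT the path differs and must be adjusted to the specific layer structure.

**Modification Scope**: The same method can be adapted to replace other components, such as feed-forward or normalization layers.

**Model-Specific Adjustments**: The exact implementation varies by model. For instance, GPT-2 models use ```transformer.h``` (or ```h``` on the bare ```GPT2Model```) instead of ```encoder.layer```.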
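
The notes above can be sketched as a small model-agnostic helper. The function name ```replace_modules``` and the toy ```Block``` model are hypothetical, introduced only for illustration. Because ```named_modules()``` yields dotted paths, the helper resolves the parent module first and sets the attribute there, rather than calling ```setattr``` on the top-level model.

```python
import torch
import torch.nn as nn

def replace_modules(model, target_type, make_replacement):
    """Swap every submodule of `target_type` for `make_replacement(old_module)`.

    Returns the dotted names of the modules that were replaced.
    """
    replaced = []
    # Materialise the list first so the module tree is not mutated while iterating
    for name, module in list(model.named_modules()):
        if name and isinstance(module, target_type):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, make_replacement(module))
            replaced.append(name)
    return replaced

# Toy model standing in for a stack of transformer blocks
class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.ff(self.norm(x))

toy = nn.Sequential(Block(16), Block(16))

# Replace every LayerNorm with an identity layer (purely illustrative)
names = replace_modules(toy, nn.LayerNorm, lambda old: nn.Identity())
print(names)
print(toy(torch.randn(2, 5, 16)).shape)
```

Returning the list of replaced names makes it easy to confirm that the type check matched anything at all; an empty list means the model uses a different attention class.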

# General pattern for Layer replacement

In [None]:
# import torch
# import torch.nn as nn
# from transformers import BertModel

# class CustomAttention(nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super(CustomAttention, self).__init__()
#         self.multihead_attention = nn.MultiheadAttention(embed_dim, num_heads)

#     def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None,
#                 encoder_attention_mask=None, past_key_value=None, output_attentions=False):
#         # Use the multihead_attention module to perform the attention operation
#         return self.multihead_attention(hidden_states, hidden_states, hidden_states,
#                                         attn_mask=attention_mask, key_padding_mask=attention_mask)[0]


## Custom LayerNorm + Residual connection + Attention

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel

class CustomLayerNorm(nn.LayerNorm):
    def __init__(self, normalized_shape, eps=1e-12):
        super(CustomLayerNorm, self).__init__(normalized_shape, eps=eps)

    def forward(self, x):
        # Custom behavior, if needed, can be added here
        return super(CustomLayerNorm, self).forward(x)


class CustomResidualConnection(nn.Module):
    def __init__(self, dropout_prob=0.1):
        super(CustomResidualConnection, self).__init__()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x, residual):
        # Custom residual connection: dropout is applied to the sublayer output `x`,
        # while the skip path `residual` is added back unchanged
        return residual + self.dropout(x)


class CustomAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(CustomAttention, self).__init__()
        # batch_first=True so inputs are (batch, seq, embed), as in Hugging Face models
        self.multihead_attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm = CustomLayerNorm(normalized_shape=embed_dim)
        self.residual_connection = CustomResidualConnection()

    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None,
                encoder_attention_mask=None, past_key_value=None, output_attentions=False):
        # Assumes attention_mask is a (batch, seq) tensor with 1 for real tokens and 0 for padding;
        # nn.MultiheadAttention expects True at the positions to ignore.
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None

        # Perform attention operation
        attention_output, _ = self.multihead_attention(
            hidden_states, hidden_states, hidden_states,
            key_padding_mask=key_padding_mask
        )

        # Apply layer normalization and residual connection
        normalized_output = self.layer_norm(attention_output)
        residual_output = self.residual_connection(normalized_output, hidden_states)

        # Return as a tuple
        return (residual_output,)


## Integration of Custom Layers

In [None]:
from transformers import AutoModel, AutoConfig

class CustomModel(nn.Module):
    def __init__(self, model_name):
        super(CustomModel, self).__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self._replace_attention_layers()

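    def _replace_attention_layers(self):
        # NOTE: Hugging Face BERT implements attention in its own classes
        # (e.g. BertSelfAttention), not nn.MultiheadAttention, so this check
        # matches nothing for BERT as written.
        for layer_name, layer_module in list(self.named_modules()):
            if isinstance(layer_module, nn.MultiheadAttention):
                # Replace the self-attention layer
                embed_dim = layer_module.embed_dim
                num_heads = layer_module.num_heads
                # layer_name is a dotted path, so set the attribute on the parent module
                parent_name, _, child_name = layer_name.rpartition('.')
                parent = self.get_submodule(parent_name) if parent_name else self
                setattr(parent, child_name, CustomAttention(embed_dim, num_heads))
                print(f"Replaced attention layer {layer_name}")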

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
                head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None,
                past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None,
                return_dict=None):
        # Ensure the forward method matches the input signature of BERT
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        return outputs

In [None]:
# Example usage
model_name = 'bert-base-uncased'
custom_model = CustomModel(model_name)

# Example input (adjust according to the model's tokenizer)
input_ids = torch.tensor([[101, 1045, 2064, 1005, 1055, 1037, 1000, 102]], dtype=torch.long)
attention_mask = torch.ones_like(input_ids)

outputs = custom_model(input_ids=input_ids, attention_mask=attention_mask)
print(outputs.keys())
print(outputs.last_hidden_state.shape)  # Example output shape


odict_keys(['last_hidden_state', 'pooler_output'])
torch.Size([1, 8, 768])


# Encoder-only Architecture - BERT

BERT-base consists of an encoder stack only: 12 identical layers, each combining self-attention with a feed-forward block, as the printed structure below shows. It has no decoder, and therefore no additional encoder_attn layer (cross-attention). Cross-attention appears only in encoder-decoder models, where it conditions the decoder’s output on the encoder representations.

```python
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
```


In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, GPT2Model

# def print_trainable_parameters(model):
#     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
#     print(f"Total Trainable Parameters: {total_params}")

gpt2 = AutoModel.from_pretrained("gpt2")


def prepare_for_peft(model):
    for param in model.parameters():
        param.requires_grad = False  # freeze the model - train adapters later
        if param.dim() == 1:
            # cast the small parameters (e.g. layernorm) to fp32 for stability
            param.data = param.data.to(torch.float32)

    # Commented this line because it's not a valid function for GPT2Model
    # model.gradient_checkpointing_enable()  # reduce number of stored activations

    model.config.gradient_checkpointing = True  # enable gradient checkpointing
    model.config.use_cache = False  # disable cache for memory efficiency
    model.config.output_hidden_states = False  # set to True if you want hidden states
    model.config.output_attentions = False  # set to True if you want attention weights

    # No need to define a separate class, we can use nn.Sequential directly
    model.lm_head = nn.Sequential(nn.Linear(model.config.hidden_size, model.config.vocab_size))
    return model

print("-"*250)
gpt2 = prepare_for_peft(gpt2)
print(gpt2)



```python

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
trainable params: 124439808 || all params: 124439808 || trainable%: 100.0

# After prepare_for_peft, an lm_head is added to the model above, and only the lm_head is trainable:

GPT2Model(
  (lm_head): Sequential(
    (0): Linear(in_features=768, out_features=50257, bias=True)
  )
)
trainable params: 38647633 || all params: 163087441 || trainable%: 23.697491825872724

```
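
The trainable-parameter figures printed above can be checked by hand, assuming the new head is a single ```nn.Linear(768, 50257)``` with a bias, as in ```prepare_for_peft```. The variable names below are illustrative.

```python
hidden_size = 768
vocab_size = 50257
base_params = 124_439_808  # GPT2Model parameter count reported above (frozen)

# nn.Linear(hidden_size, vocab_size) has a weight matrix plus a bias vector
lm_head_params = hidden_size * vocab_size + vocab_size
all_params = base_params + lm_head_params
trainable_pct = 100 * lm_head_params / all_params

print(lm_head_params, all_params, round(trainable_pct, 4))
```

The result matches the reported counts, which confirms that the trainable share is exactly the new head.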

# Local Attention & Global Attention

https://github.com/lucidrains/local-attention


### Local vs Global Attention

**Local Attention** focuses on a subset of the input sequence within a fixed-size window, offering computational efficiency by reducing the scope of attention. It’s useful in tasks like language modeling and image processing, where nearby context is more relevant.

**Global Attention** considers all positions of the input, providing a comprehensive view of the entire sequence. It’s more resource-intensive but excels in capturing long-range dependencies, such as in machine translation or image classification.

### Comparison:
- **Scope**: Local attention focuses on a neighborhood; global attends to all positions.
- **Efficiency**: Local is computationally efficient, global is more costly but captures broader context.
- **Use**: Local is for tasks with local dependencies, while global is for tasks needing long-range context.
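
To make the efficiency point concrete, the sketch below counts how many (query, key) pairs each scheme scores. The helper name ```attended_pairs``` is hypothetical, and the window convention (```window // 2``` neighbours on each side) is an assumption chosen to match the mask-building code used elsewhere in this notebook.

```python
def attended_pairs(seq_len, window=None):
    """Count (query, key) pairs that are scored; window=None means global attention."""
    if window is None:
        return seq_len * seq_len
    half = window // 2
    return sum(
        min(seq_len, i + half + 1) - max(0, i - half)
        for i in range(seq_len)
    )

for n in (6, 1024):
    print(n, attended_pairs(n), attended_pairs(n, window=3))
```

Global attention grows quadratically with sequence length, while local attention grows roughly as sequence length times window size.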

### Example:
- **LocalAttention Class**: Implements a custom attention layer using a local window mask.
- **CustomTransformer Class**: Combines local and global attention mechanisms with a feedforward network.

This framework demonstrates how local and global attention can be integrated into a transformer model for various applications.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Custom Local Attention Layer
class LocalAttention(nn.Module):
    def __init__(self, embed_dim, window_size=5):
        super(LocalAttention, self).__init__()
        self.embed_dim = embed_dim
        self.window_size = window_size
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

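    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Scaled dot-product scores: (batch_size, seq_len, seq_len)
        attention_scores = torch.matmul(q, k.transpose(-1, -2)) / (self.embed_dim ** 0.5)
        print(attention_scores.shape)
        # Apply local window mask: each position attends only to neighbours
        # within window_size // 2 steps on either side
        positions = torch.arange(seq_len, device=x.device)
        outside_window = (positions[None, :] - positions[:, None]).abs() > self.window_size // 2
        mask = torch.zeros_like(attention_scores).masked_fill(outside_window, -float('inf'))
        print(mask.shape)
        attention_scores = attention_scores + mask

        attention_weights = F.softmax(attention_scores, dim=-1)
        context = torch.matmul(attention_weights, v)

        return context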

# Custom Transformer Model using both Local and Global Attention
class CustomTransformer(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size=5):
        super(CustomTransformer, self).__init__()
        self.local_attention = LocalAttention(embed_dim, window_size)
        self.global_attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)  # inputs are (batch, seq, embed)
        self.feedforward = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.ReLU(),
            nn.Linear(4 * embed_dim, embed_dim)
        )

    def forward(self, x):
        # Local Attention
        local_context = self.local_attention(x)

        # Global Attention
        global_context, _ = self.global_attention(x, x, x)

        # Combine local and global contexts
        combined_context = local_context + global_context

        # Feedforward layer
        output = self.feedforward(combined_context)

        return output

# Example usage
if __name__ == "__main__":
    embed_dim = 16
    num_heads = 4
    window_size = 5
    batch_size = 8
    seq_len = 10

    # Create an instance of the custom transformer model
    model = CustomTransformer(embed_dim, num_heads, window_size)

    # Generate a dummy input tensor
    input_tensor = torch.randn(batch_size, seq_len, embed_dim)

    # Forward pass through the model
    output_tensor = model(input_tensor)

    print("Input Tensor Shape:", input_tensor.shape)
    print("Output Tensor Shape:", output_tensor.shape)


torch.Size([8, 10, 10])
torch.Size([8, 10, 10])
Input Tensor Shape: torch.Size([8, 10, 16])
Output Tensor Shape: torch.Size([8, 10, 16])


In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

# class SparseAttention(nn.Module):
#     def __init__(self, input_dim, embed_dim, num_heads, window):
#         super(SparseAttention, self).__init__()
#         self.head_dim = embed_dim // num_heads
#         self.num_heads = num_heads
#         self.window = window
#         self.q_proj = nn.Linear(input_dim, embed_dim)
#         self.k_proj = nn.Linear(input_dim, embed_dim)
#         self.v_proj = nn.Linear(input_dim, embed_dim)
#         self.output_proj = nn.Linear(embed_dim, input_dim)

#     def _create_local_attention_mask(self, seq):
#         mask = torch.zeros(seq, seq)  # [seq, seq]
#         for i in range(seq):
#             start = max(0, i - self.window // 2)
#             end = min(seq, i + self.window // 2 + 1)
#             mask[start:end, i] = 1
#         mask = mask.unsqueeze(0).unsqueeze(0).to(torch.bool)  # [1, 1, seq, seq]
#         return mask

#     def forward(self, x, local_attn=False):
#         batch, seq, embed_dim = x.size()
#         q = self.q_proj(x)
#         k = self.k_proj(x)
#         v = self.v_proj(x)

#         q = q.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # [batch,n_heads,seq,head_dim]
#         k = k.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
#         v = v.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

#         attn_scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5) # [batch,n_heads,q_seq,k_seq]

#         if local_attn:
#             mask = self._create_local_attention_mask(seq)
#             attn_scores = attn_scores.masked_fill(~mask, float('-inf'))

#         attn_weights = torch.softmax(attn_scores, dim=-1)
#         attn_output = torch.matmul(attn_weights, v) # [batch,n_heads,q_seq,head_dim]

#         attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(batch, seq, -1) # [batch,q_seq,input_dim]
#         attn_output = self.output_proj(attn_output)

#         return attn_output, attn_weights

class SparseAttention(nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads, window):
        super(SparseAttention, self).__init__()
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads
        self.window = window
        self.q_proj = nn.Linear(input_dim, embed_dim)
        self.k_proj = nn.Linear(input_dim, embed_dim)
        self.v_proj = nn.Linear(input_dim, embed_dim)
        self.output_proj = nn.Linear(embed_dim, input_dim)

    def _create_local_attention_mask(self, seq):
        mask = torch.zeros(seq, seq, dtype=torch.bool)  # [seq, seq]
        for i in range(seq):
            start = max(0, i - self.window // 2)
            end = min(seq, i + self.window // 2 + 1)
            mask[start:end, i] = 1
        mask = mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq, seq]
        return mask

    def forward(self, x, local_attn=False):
        batch_size, seq_len, embed_dim = x.size() # [batch,seq_len,head_dim * num_heads]
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape and transpose for multi-head attention
        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
        k = k.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
        v = v.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [batch, num_heads, seq_len, seq_len]

        if local_attn:
            # True marks positions inside the window, so mask out everything else
            mask = self._create_local_attention_mask(seq_len).to(x.device)
            attn_scores = attn_scores.masked_fill(~mask, float('-inf'))

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)  # [batch, num_heads, seq_len, head_dim]

        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.num_heads * self.head_dim)  # [batch, seq_len, embed_dim]
        attn_output = self.output_proj(attn_output)

        return attn_output, attn_weights

class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.model = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')
        self._replace_attention_layer() # can add more custom layers like this

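    def _replace_attention_layer(self):
        # NOTE: Hugging Face OPT implements attention in its own OPTAttention class,
        # not nn.MultiheadAttention, so this check matches nothing and the model is
        # left unchanged (hence the identical outputs in the comparison below).
        for name, module in list(self.model.named_modules()):
            if isinstance(module, nn.MultiheadAttention):
                # Extract parameters from the existing module
                input_dim = module.embed_dim
                num_heads = module.num_heads
                # Create the replacement attention layer
                new_attention_layer = SparseAttention(
                    input_dim=input_dim,
                    embed_dim=input_dim,
                    num_heads=num_heads,
                    window=4,  # Example window size
                )
                # `name` is a dotted path, so set the attribute on the parent module
                parent_name, _, child_name = name.rpartition('.')
                parent = self.model.get_submodule(parent_name) if parent_name else self.model
                setattr(parent, child_name, new_attention_layer)

    def forward(self, *args, local_attn=False, **kwargs):
        # `local_attn` is accepted here but is not forwarded to the wrapped model
        return self.model(*args, **kwargs)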


# Tokenization and Model Execution
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
inputs = tokenizer("If you modify the model",return_tensors='pt')
inputs.items()

dict_items([('input_ids', tensor([[    2,  1106,    47, 23209,     5,  1421]])), ('attention_mask', tensor([[1, 1, 1, 1, 1, 1]]))])

In [None]:
inputs["input_ids"][0].shape[0]

6

In [None]:
def create_local_attention_mask(seq, window=3):
    mask = torch.zeros(seq, seq, dtype=torch.bool)  # [seq, seq]
    for i in range(seq):
        start = max(0, i - window // 2)
        end = min(seq, i + window // 2 + 1)
        mask[start:end, i] = 1
    mask = mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq, seq]
    return mask

window = 3
seq = inputs["input_ids"][0].shape[0]
mask = create_local_attention_mask(seq, window)
print(seq)
print(mask)

6
tensor([[[[ True,  True, False, False, False, False],
          [ True,  True,  True, False, False, False],
          [False,  True,  True,  True, False, False],
          [False, False,  True,  True,  True, False],
          [False, False, False,  True,  True,  True],
          [False, False, False, False,  True,  True]]]])
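
The same band mask can also be built without a Python loop. The sketch below is one possible vectorised variant under the same window convention; the function name ```local_attention_mask``` is an illustrative choice and not part of the code above.

```python
import torch

def local_attention_mask(seq, window=3):
    """Boolean band mask: True where |query index - key index| <= window // 2."""
    idx = torch.arange(seq)
    band = (idx[:, None] - idx[None, :]).abs() <= window // 2
    return band[None, None]  # [1, 1, seq, seq]

mask = local_attention_mask(6, window=3)
print(mask.shape)
print(mask[0, 0].int())
```

Since the comparison runs on whole tensors, this avoids the per-row loop, which matters for long sequences.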


### Next Word Prediction

In [None]:
model = CustomModel()
outputs = model(**inputs, local_attn=True)


logits = outputs.logits  # Raw logits (scores for each token in the vocabulary)
hidden_states = outputs.hidden_states  # Optional, if return_dict=True and output_hidden_states=True
attention_weights = outputs.attentions  # Optional, if return_dict=True and output_attentions=True
print("Logits shape:", logits.shape) # [batch_size, seq_length, vocab_size]
predicted_token_ids = torch.argmax(logits, dim=-1)  # [batch, seq]
decoded_output = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)

print("Decoded output:", decoded_output)

Logits shape: torch.Size([1, 6, 50272])
Decoded output: ["\n you're the game,"]


In [None]:
base_model = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')

# Example usage
base_outputs = base_model(**inputs)
logits_base_model = base_outputs.logits
pred_token_id_base_model = torch.argmax(logits_base_model ,dim=-1)
decoded_output_base_model = tokenizer.batch_decode(pred_token_id_base_model, skip_special_tokens=True)
print("Logits shape(Base model):", logits_base_model .shape)
print("Decoded output(Base model):", decoded_output_base_model)

Logits shape(Base model): torch.Size([1, 6, 50272])
Decoded output(Base model): ["\n you're the game,"]

