# Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models

[arXiv](https://arxiv.org/abs/2402.19427)

## Model Architecture
All our models contain the following components: 
- (i) a residual block
- (ii) an MLP block
- (iii) a temporal-mixing block. 

While (i) and (ii) are the same across all models, we consider three temporal mixing blocks: global Multi-Query Attention (MQA), local (sliding-window) MQA and our proposed recurrent block. As part of the recurrent block we use the Real-Gated Linear Recurrent Unit (RG-LRU) – a novel recurrent layer inspired by the Linear Recurrent Unit [Orvieto et al., 2023b](https://arxiv.org/abs/2303.06349).
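To make the difference between global and local (sliding-window) attention concrete, here is a minimal sketch of the attention mask each one implies. The function name `sliding_window_causal_mask` and the convention that the window counts the current token are assumptions made for illustration, not details taken from the paper.

```python
import torch


def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    """Boolean mask; True where query position i may attend to key position j.

    Assumed convention: position i sees itself and the previous `window - 1` tokens.
    """
    pos = torch.arange(seq_len)
    offset = pos[:, None] - pos[None, :]  # i - j
    # offset >= 0 enforces causality; offset < window limits how far back i can look.
    return (offset >= 0) & (offset < window)


# Applying the mask to raw attention scores before the softmax.
scores = torch.randn(5, 5)
mask = sliding_window_causal_mask(seq_len=5, window=2)
attn = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
```

Setting `window >= seq_len` recovers the ordinary causal mask used by global attention, so the two variants differ only in how far back each token is allowed to look.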

The residual block, as shown in Figure 2(a), defines the global structure of our models and is inspired by pre-norm Transformers (Xiong et al., 2020). After embedding the input sequence we pass it through $N$ such blocks ($N$ denoting the model depth), and then we apply RMSNorm [Zhang and Sennrich, 2019](https://arxiv.org/abs/1910.07467) to produce the final activations. To compute the token probabilities we apply a final linear layer followed by a softmax. The weights of this layer are shared with the input embedding layer.
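The data flow described above (embed, apply $N$ blocks, RMSNorm, tied output projection, softmax) can be sketched as follows. The class name `TiedLMSkeleton`, the `eps` value, and the use of `nn.Identity` as a stand-in for real residual blocks are all assumptions for illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedLMSkeleton(nn.Module):
    """Hypothetical skeleton: embed -> N blocks -> RMSNorm -> tied linear -> softmax."""

    def __init__(self, vocab_size, dim, blocks, eps=1e-6):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList(blocks)
        self.final_scale = nn.Parameter(torch.ones(dim))  # learnable RMSNorm scale
        self.eps = eps  # assumed value, not taken from the paper

    def forward(self, token_ids):
        x = self.embed(token_ids)
        for block in self.blocks:
            x = block(x)
        # Final RMSNorm: divide by the root-mean-square over features, then rescale.
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        x = x * self.final_scale
        # Weight tying: reuse the embedding matrix as the output projection.
        logits = F.linear(x, self.embed.weight)
        return F.softmax(logits, dim=-1)


torch.manual_seed(0)
model = TiedLMSkeleton(vocab_size=32, dim=8, blocks=[nn.Identity() for _ in range(3)])
probs = model(torch.randint(0, 32, (2, 5)))  # shape: (batch, seq_len, vocab_size)
```

Because the output projection reuses `self.embed.weight`, the model holds a single `vocab_size x dim` matrix rather than two.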

## Residual Block

![Griffin](https://arxiv.org/html/2402.19427v1/x3.png)

Figure 2: a) The main backbone of our model architecture is the residual block, which is stacked $N$ times. b) The gated MLP block that we use. c) The recurrent block that we propose as an alternative to Multi-Query Attention (MQA). It uses our proposed RG-LRU layer, defined in Section 2.4.

The residual block contains two components, applied in order. The first component takes the hidden state $\chi$ and applies an RMSNorm [Zhang and Sennrich, 2019](https://arxiv.org/abs/1910.07467), followed by the temporal-mixing block. We then merge the output with a skip connection from $\chi$ through addition. Similarly, the second component applies RMSNorm followed by the MLP block, then merges its output with a skip connection from the input of that RMSNorm. This block is illustrated in Figure 2(a).
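The two-step structure described above can be sketched as a small module. This is a minimal sketch, not the authors' implementation: the names `RMSNorm` and `ResidualBlock`, the `eps` default, and the placeholder sub-modules in the usage lines are assumptions, and any temporal-mixing or MLP module that preserves the feature dimension could be plugged in.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable scale (eps is an assumed default)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms_inv = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms_inv * self.weight


class ResidualBlock(nn.Module):
    """Pre-norm residual block: (norm -> temporal mixing -> add), then (norm -> MLP -> add)."""

    def __init__(self, dim, temporal_mixing, mlp):
        super().__init__()
        self.norm_mix = RMSNorm(dim)
        self.temporal_mixing = temporal_mixing
        self.norm_mlp = RMSNorm(dim)
        self.mlp = mlp

    def forward(self, x):
        # First component: skip connection around RMSNorm + temporal mixing.
        x = x + self.temporal_mixing(self.norm_mix(x))
        # Second component: skip connection around RMSNorm + MLP.
        x = x + self.mlp(self.norm_mlp(x))
        return x


# Placeholder sub-modules, used only to check that shapes are preserved.
block = ResidualBlock(dim=8, temporal_mixing=nn.Identity(), mlp=nn.Linear(8, 8, bias=False))
x = torch.randn(2, 5, 8)
y = block(x)
```

Because both sub-blocks are added to their input rather than replacing it, a block whose sub-modules output zeros reduces to the identity, which is what makes deep stacks of such blocks easy to optimize.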

```python
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn import Embedding
```

## Gated MLP Block

![GatedMLP](https://github.com/AmbiTyga/Research-Bookmark/blob/main/Griffin/Gated%20MLP.png?raw=true)

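```python
class GatedMLPBlock(nn.Module):
    """Gated MLP: a GeLU-activated branch multiplied element-wise with a linear branch."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()

        # Two parallel projections from input_dim up to hidden_dim, one per branch.
        self.linearA = nn.Linear(input_dim, hidden_dim, bias=False)
        self.linearB = nn.Linear(input_dim, hidden_dim, bias=False)

        # Final projection from hidden_dim back down to input_dim.
        self.linearCombined = nn.Linear(hidden_dim, input_dim, bias=False)

    def forward(self, x):
        # Gate branch: GeLU is the only non-linearity, as in Figure 2(b).
        gelu_out = F.gelu(self.linearA(x))
        # Linear branch: no activation.
        outputB = self.linearB(x)
        # Merge the two branches by element-wise multiplication.
        combined = gelu_out * outputB
        outputCombined = self.linearCombined(combined)
        return outputCombined
```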