# Implementation of Fig.2 : Attention’s RNN Cell

In [1]:
import torch
# Seed for reproducibility
torch.manual_seed(42)


<torch._C.Generator at 0x209cbf850b0>

In [10]:
import torch.nn as nn

class AttentionRNNCell(nn.Module):
    def __init__(self, d_model):
        super(AttentionRNNCell, self).__init__()
        self.d_model = d_model  # Dimensionality of keys/queries/values
    
    def forward_iterative(self, q, k, v, prev_a=None, prev_c=None, prev_m=None):
        """
        Perform a single step of the recurrent attention computation.
        Args:
            q (Tensor): Query vector, shape (batch_size, d_model).
                - Corresponds to the query of the current time step.

            k (Tensor): Key vector, shape (batch_size, d_model).
            v (Tensor): Value vector, shape (batch_size, d_model).
                - Corresponds to the values and keys from the first step to the current time step.


            prev_a (Tensor): Previous a_k, shape (batch_size, d_model). Defaults to 0.
            prev_c (Tensor): Previous c_k, shape (batch_size, 1). Defaults to 0.
            prev_m (Tensor): Previous m_k, shape (batch_size, 1). Defaults to 0.
        Returns:
            a_k (Tensor), c_k (Tensor), m_k (Tensor)
        """
        batch_size = q.size(0)

        # For debugging purposes
        assert k.shape[-1] == v.shape[-1] == self.d_model

        # Initialize previous states if not provided
        if prev_a is None:
            prev_a = torch.zeros(batch_size, self.d_model).to(k.device)
        if prev_c is None:
            prev_c = torch.zeros(batch_size, self.d_model).to(k.device)
        if prev_m is None:
            prev_m = torch.full((batch_size, 1), 0).to(k.device)

        # Compute scores s_k = q . k^T
        s_k = torch.sum(q * k, dim=-1, keepdim=True)  # Shape: (batch_size, 1)

        # Update m_k (max cumulative score)
        m_k = torch.max(s_k, prev_m)

        # Compute exp terms for stability
        exp_term1 = torch.exp(prev_m - m_k)  # Shape: (batch_size, 1)
        exp_term2 = torch.exp(s_k - m_k)    # Shape: (batch_size, 1)

        # Update a_k and c_k
        a_k = prev_a * exp_term1 + v * exp_term2  # Shape: (batch_size, d_model)
        c_k = prev_c * exp_term1 + exp_term2      # Shape: (batch_size, 1)

        return a_k, c_k, m_k




## Itereative Computation
- Recurrently token-by-token (i.e., sequentially) in O(1)
memory

In [None]:
class Model_NN_trainable_Model():
    def __init__(self, model,batch_size, input_dim, d_model, optimizer):
        self.d_model = d_model
        self.batch_size = batch_size
        self.optimizer = optimizer
        self.input_dim = input_dim
        self.model = AttentionRNNCell(d_model)
    def forward_QKV():
        

        return Q, K, V

In [11]:

# Configuration
batch_size = 1  # Number of sequences to process at once
input_dim = 10  # Sequence length (N=10)
d_model = 2     # Dimensionality of embeddings ( in text context), values representation

# Dummy Data
q = torch.randn(batch_size, d_model)  # Query vector
k = torch.randn(input_dim, batch_size, d_model)  # Key matrix
v = torch.randn(input_dim, batch_size, d_model)  # Value matrix

# Initialize Cell
cell = AttentionRNNCell(d_model)
a_k, c_k, m_k = None, None, None

# Process Iteratively
for i in range(input_dim):
    a_k, c_k, m_k = cell.forward_iterative(q, k[i, :, :], v[i, :, :], a_k, c_k, m_k)
    
import numpy as np
if c_k != np.zeros((1,1)):
    final_output = a_k/ c_k
    
print("Shape of a_k:", a_k.shape, "\nFinal a_k:\n", a_k)
print("Shape of c_k:", c_k.shape, "\nFinal c_k:\n", c_k)
print("Shape of m_k:", m_k.shape, "\nFinal m_k:\n", m_k)
print("Final output:\n", final_output)

Shape of a_k: torch.Size([1, 2]) 
Final a_k:
 tensor([[-0.5620,  1.3149]])
Shape of c_k: torch.Size([1, 2]) 
Final c_k:
 tensor([[4.0781, 4.0781]])
Shape of m_k: torch.Size([1, 1]) 
Final m_k:
 tensor([[0.]])
Final output:
 tensor([[-0.1378,  0.3224]])


## Parallel Computation
- Many-to-many RNN
    -  we can compute
{ak}N
k=1, {ck}N
k=1, and {mk}N
k=1 via the parallel scan
algorithm and 
**AFTERWARDS** combine ak and ck to compute Attention(q, x1:k).
<html>

### Image 1: MTM Diagram
![MTM Diagram](MTM.png "MTM Diagram Title")

### Image 2: Parallel Prefix Scan Diagram
- This is the first intuation behind the implementation, yet we are going to use a litteral implementation of it as it is

![Parallel Prefix Scan Diagram](ParallelPrefixScan.png "Parallel Prefix Scan Diagram Title")

</html>



In [83]:
# Define the parallel attention implementation as a module
class ParallelAttentionScan(nn.Module):
    def __init__(self):
        super(ParallelAttentionScan, self).__init__()

    def combine(self, mA, uA, wA, mB, uB, wB):
        mAB = torch.max(mA, mB)
        expA = torch.exp(mA - mAB)
        expB = torch.exp(mB - mAB)
        uAB = uA * expA + uB * expB
        wAB = wA * expA + wB * expB
        return mAB, uAB, wAB

    def forward(self, q, k, v):
        input_dim, batch_size, d_model = k.size()
        # {(m{i}, u{i}, w{i})}N,i=1 == {(si, 1, vi)}N,i=1
        
        s_initial = torch.sum(q * k[0], dim=-1, keepdim=True)
        m = s_initial
        u = torch.ones(batch_size, 1)
        w = v[0]

        for i in range(1, input_dim):
            m, u, w = self.combine(m, u, w, k[i], torch.exp(k[i] - m), v[i])

        return w / u
    
# Example input tensors
batch_size = 2
input_dim = 5
d_model = 3

q = torch.randn(batch_size, d_model)  # Query vector (batch_size, d_model)
k = torch.randn(input_dim, batch_size, d_model)  # Key matrix (input_dim, batch_size, d_model)
v = torch.randn(input_dim, batch_size, d_model)  # Value matrix (input_dim, batch_size, d_model)

# Run parallel attention scan
ATT_parallel = ParallelAttentionScan()
o = ATT_parallel(q, k, v)
print("Output shape:", o.shape)
print("Attention Output:\n", o)


Output shape: torch.Size([2, 3])
Attention Output:
 tensor([[-0.1760, -0.0023,  0.0350],
        [ 0.1519, -0.1678, -0.2512]])


# Iterative VS Parallel

In [84]:
# Seed for reproducibility  
torch.manual_seed(42)

# Configuration
batch_size = 1
input_dim = 100
d_model = 3

q = torch.randn(batch_size, d_model)  # Query vector
k = torch.randn(input_dim, batch_size, d_model)  # Key matrix
v = torch.randn(input_dim, batch_size, d_model)  # Value matrix

# Run parallel attention scan
ATT_parallel = ParallelAttentionScan()
parallel_output = ATT_parallel(q, k, v)

# Initialize RNN Cell
cell = AttentionRNNCell(d_model)
a_k, c_k, m_k = None, None, None

# Process Iteratively
for i in range(input_dim):
    a_k, c_k, m_k = cell.forward_iterative(q, k[i, :, :], v[i, :, :], a_k, c_k, m_k)

iterative_output = a_k / c_k  # Normalize the result to get final attention output

# Compute Absolute and Relative Errors
absolute_error = torch.abs(parallel_output - iterative_output)
relative_error = torch.abs(parallel_output - iterative_output) / torch.abs(iterative_output)

print("Parallel Output:", parallel_output)
print("Iterative Output:", iterative_output)

print("Absolute Error:", absolute_error)
print("Relative Error:", relative_error)

print("Mean Absolute Error:", torch.mean(absolute_error))
print("Mean Relative Error:", torch.mean(relative_error))

Parallel Output: tensor([[0.2147, 0.0529, 0.0588]])
Iterative Output: tensor([[ 0.0920, -0.0360,  0.0753]])
Absolute Error: tensor([[0.1227, 0.0890, 0.0165]])
Relative Error: tensor([[1.3332, 2.4698, 0.2192]])
Mean Absolute Error: tensor(0.0761)
Mean Relative Error: tensor(1.3407)


# Exporting into onnx

In [None]:
# # Instantiate the models
# parallel_model = ParallelAttentionScan()
# iterative_model = IterativeAttentionScan(d_model=3)

# # Dummy inputs
# batch_size = 2
# input_dim = 5
# d_model = 3
# q = torch.randn(batch_size, d_model)
# k = torch.randn(input_dim, batch_size, d_model)
# v = torch.randn(input_dim, batch_size, d_model)

# # Export Parallel Attention Model to ONNX
# torch.onnx.export(
#     parallel_model,
#     (q, k, v),
#     "parallel_attention_scan.onnx",
#     input_names=["q", "k", "v"],
#     output_names=["output"],
#     dynamic_axes={"k": {0: "input_dim"}, "v": {0: "input_dim"}},
#     opset_version=12,
# )

# # Export Iterative Attention Model to ONNX
# torch.onnx.export(
#     iterative_model,
#     (q, k, v),
#     "iterative_attention_scan.onnx",
#     input_names=["q", "k", "v"],
#     output_names=["output"],
#     dynamic_axes={"k": {0: "input_dim"}, "v": {0: "input_dim"}},
#     opset_version=12,
# )

# print("Both models have been exported as ONNX.")


## ONNX models Visualization
### Image 1: Iterative Model
![Iterative Model](itereative.png "")

### Image 2: Parallel Model
![Parallel Model](paralell.png "")

</html>

