<a href="https://colab.research.google.com/github/Mantissagithub/Kv-cache/blob/main/Kv_cache_nn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

import torch.nn as nn
import torch

class SimpleNN(nn.Module):
    """Two-layer MLP (Linear -> sigmoid -> Linear) that memoises its outputs.

    Every distinct input seen by ``forward`` is stored in ``self.cache``
    together with the output computed for it. A later call with an input whose
    nested-list representation is identical returns the stored tensor object
    instead of recomputing it.

    Caveats:
      * Entries are never invalidated automatically. If the weights change
        (e.g. after an optimizer step) cached outputs become stale; call
        ``clear_cache()`` after updating parameters.
      * The cache is unbounded and cached outputs keep their autograd graph
        alive, so memory grows with the number of distinct inputs.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

        # Maps str(x.tolist()) -> output tensor previously computed for x.
        self.cache = {}

    @staticmethod
    def _cache_key(x):
        """Return the hashable cache key for input tensor ``x``."""
        # The nested-list repr encodes both the values and the shape of x.
        return str(x.tolist())

    def clear_cache(self):
        """Drop all memoised outputs (call this after the weights change)."""
        self.cache.clear()

    def forward(self, x):
        """Return ``fc2(sigmoid(fc1(x)))``, reusing a cached result when possible.

        Args:
            x: input tensor whose last dimension is ``input_size``.

        Returns:
            Tensor whose last dimension is ``output_size``. On a cache hit this
            is the very tensor object returned by the earlier call.
        """
        # tolist() copies every element into Python objects, so build the key
        # only once per call.
        key = self._cache_key(x)
        cached = self.cache.get(key)
        if cached is not None:
            return cached

        out = self.fc2(torch.sigmoid(self.fc1(x)))
        self.cache[key] = out
        return out


# Build a small network: 10 input features -> 5 hidden units -> 3 outputs.
model = SimpleNN(10, 5, 3)

# A single random sample (batch of one) with 10 features.
x = torch.randn((1, 10))

# Run the network on the sample and display the result.
output = model(x)
print(output)

tensor([[0.2988, 0.0074, 0.0145]], grad_fn=<AddmmBackward0>)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionLayer(nn.Module):
    """Multi-head self-attention with an optional key/value (KV) cache.

    Each call projects ``x`` to queries, keys and values. It splits them into
    ``heads`` heads of width ``embed_size // heads`` and computes scaled
    dot-product attention independently per head. The heads are then
    concatenated and passed through an output projection.

    The key/value projections of every call are appended (detached) to
    ``self.cache``. When ``forward`` is called with ``use_cache=True``, the
    cached keys/values from earlier calls are prepended to the current ones.
    The new queries then attend over every position seen since the last
    ``clear_cache()``. No attention mask (causal or padding) is applied.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size must be divisible by heads"

        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

        # One entry per forward call: detached (N, seq_length, embed_size)
        # key/value projections, in call order.
        self.cache = {'keys': [], 'values': []}

    def clear_cache(self):
        """Forget all cached keys and values."""
        self.cache['keys'].clear()
        self.cache['values'].clear()

    def _split_heads(self, t):
        """Reshape ``(N, L, embed_size)`` into ``(N, L, heads, head_dim)``."""
        n, length, _ = t.shape
        return t.reshape(n, length, self.heads, self.head_dim)

    def forward(self, x, use_cache=False):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape ``(N, seq_length, embed_size)``.
            use_cache: if True, the queries also attend over the keys/values
                cached by previous calls. Those calls must have used the same
                batch size ``N``.

        Returns:
            Tensor of shape ``(N, seq_length, embed_size)``.
        """
        N, seq_length, _ = x.shape

        keys = self.keys(x)
        values = self.values(x)
        queries = self.queries(x)

        if use_cache and self.cache['keys']:
            # Prepend previously seen positions along the sequence axis.
            all_keys = torch.cat(self.cache['keys'] + [keys], dim=1)
            all_values = torch.cat(self.cache['values'] + [values], dim=1)
        else:
            all_keys, all_values = keys, values

        # Detach before caching so old autograd graphs are not kept alive.
        self.cache['keys'].append(keys.detach())
        self.cache['values'].append(values.detach())

        q = self._split_heads(queries)
        k = self._split_heads(all_keys)
        v = self._split_heads(all_values)

        # Contract over head_dim only, keeping heads separate:
        # energy has shape (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", q, k)
        # Each head's dot product spans head_dim features, so scale by sqrt(head_dim).
        attention = F.softmax(energy / (self.head_dim ** 0.5), dim=-1)

        # Weighted sum of values per head, then merge heads back together.
        out = torch.einsum("nhqk,nkhd->nqhd", attention, v)
        out = out.reshape(N, seq_length, self.embed_size)

        return self.fc_out(out)

class SimpleTransformer(nn.Module):
    """A stack of ``num_layers`` AttentionLayer modules applied one after another.

    NOTE(review): there are no residual connections, normalisation or
    feed-forward sublayers between the attention layers. This is therefore a
    pure-attention stack rather than a full Transformer block, and such stacks
    tend to make per-position outputs converge as depth grows.
    """

    def __init__(self, embed_size, heads, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            AttentionLayer(embed_size, heads) for _ in range(num_layers)
        )

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final output."""
        hidden = x
        for attention in self.layers:
            hidden = attention(hidden)
        return hidden

# Example usage
if __name__ == "__main__":
    # Hyper-parameters: 64-dim embeddings split across 8 heads, 4 stacked layers.
    embed_size, heads, num_layers = 64, 8, 4

    model = SimpleTransformer(embed_size, heads, num_layers)

    # One random sequence: (batch_size=1, seq_length=10, embed_size=64).
    x = torch.randn(1, 10, embed_size)

    # Run the stack and display the result.
    output = model(x)
    print(output)

tensor([[[-0.0409, -0.0239, -0.0342, -0.0294, -0.1071, -0.1014, -0.0431,
          -0.0723,  0.0588,  0.0783, -0.0268,  0.0219,  0.0176,  0.0707,
           0.0550, -0.0693,  0.1000, -0.0458, -0.0565,  0.0418, -0.0349,
           0.0107, -0.0377,  0.0795, -0.0905, -0.1048, -0.0808, -0.0035,
          -0.1311, -0.0102, -0.0910,  0.1705,  0.1249, -0.1070, -0.0981,
          -0.0302, -0.0847, -0.0929, -0.0434,  0.0531,  0.0366, -0.0235,
          -0.0118,  0.0608,  0.0411, -0.0035,  0.0136, -0.1386,  0.0021,
           0.0931, -0.0306,  0.0929, -0.0049,  0.1363, -0.1158,  0.0302,
          -0.0619,  0.1215, -0.0514,  0.1102,  0.0071,  0.0136, -0.1052,
           0.1326],
         [-0.0409, -0.0239, -0.0342, -0.0294, -0.1071, -0.1014, -0.0431,
          -0.0723,  0.0588,  0.0783, -0.0268,  0.0219,  0.0176,  0.0707,
           0.0550, -0.0693,  0.1000, -0.0458, -0.0565,  0.0418, -0.0349,
           0.0107, -0.0377,  0.0795, -0.0905, -0.1048, -0.0808, -0.0035,
          -0.1311, -0.0102, -0.