In [2]:
import torch
import torch.nn as nn
import numpy as np

In [None]:
class Attention(nn.Module):
    '''
    Experimental implementation of the scaled dot-product attention mechanism
    and not yet fully implemented
    '''
    def __init__(self, embed_size, heads):
        super(Attention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask=None):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        query = query.reshape(N, query_len, self.heads, self.head_dim)

        energy  = torch.einsum("nqhd,nkhd->nhqk", [query, keys])
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nqhd,nkhd->nhqk", [query, keys])
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim)
        # out shape: (N, query_len, heads, head_dim)
        
        out = torch.einsum()
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim) 
        # (N, query_len, heads, head_dim)

        # For demonstration purposes, we'll just return the reshaped tensors
        return values, keys, query, energy

In [10]:
# Example usage
# if __name__ == "__main__":
embed_size = 8
heads = 2
batch_size = 1
seq_length = 4

# Create dummy data
values = torch.randn(batch_size, seq_length, embed_size)
keys = torch.randn(batch_size, seq_length, embed_size)
query = torch.randn(batch_size, seq_length, embed_size)

print("Original shapes:")
print(f"Shape of the Values: {values.shape}")
print(f"Shape of the Keys: {keys.shape}")
print(f"Shape of the Query: {query.shape}")
# Create the Attention layer
attention = Attention(embed_size, heads)

# Forward pass
reshaped_values, reshaped_keys, reshaped_query, energy = attention(values, keys, query)

print("\nReshaped shapes:")
print("Reshaped Values shape:", reshaped_values.shape)
print("Reshaped Keys shape:", reshaped_keys.shape)
print("Reshaped Query shape:", reshaped_query.shape)
print(f"Shape of Energy: {energy.shape}")

Original shapes:
Shape of the Values: torch.Size([1, 4, 8])
Shape of the Keys: torch.Size([1, 4, 8])
Shape of the Query: torch.Size([1, 4, 8])

Reshaped shapes:
Reshaped Values shape: torch.Size([1, 4, 2, 4])
Reshaped Keys shape: torch.Size([1, 4, 2, 4])
Reshaped Query shape: torch.Size([1, 4, 2, 4])
Shape of Energy: torch.Size([1, 2, 4, 4])


### Use Einsum

General Rules when using Einsum


1.) **Repeating letters in different inputs means those values will be multiplied and those products will be the output**

2.) **Omitting a letter means that axis will be summed**

3.) **We can return the unsummed axes in any order that we like**

In [15]:
np.random.seed(23)

# A = np.random.rand(3, 3); A
A = np.array([[1,2,3],[4,5,6],[7,8,9]]); 
B = np.array([[2,3,4,1],[5,6,7,2],[8,9,10,7]]); 
M = np.empty((3,4))

for i in range(3):
    for j in range(4):
        total = 0
        for k in range(3):
            total += A[i,k] * B[k,j]
        
        M[i,j] = total
        


M

array([[ 36.,  42.,  48.,  26.],
       [ 81.,  96., 111.,  56.],
       [126., 150., 174.,  86.]])

In [16]:
M = np.einsum('ij,jk->ik', A, B); M

array([[ 36,  42,  48,  26],
       [ 81,  96, 111,  56],
       [126, 150, 174,  86]])

In [17]:
x = np.array([1,2,3])
sum_x = np.einsum('i->', x); sum_x

6

In [20]:
x = np.ones((5,4,3))
print(x.shape)
np.einsum('ijk->kji', x)
print(x.shape)

(5, 4, 3)
(5, 4, 3)


In [24]:
x = torch.rand((2,3)); 

##Permutation of tensors
# torch.einsum('ij->ji', x)
x

tensor([[0.0077, 0.8053, 0.0969],
        [0.6909, 0.6676, 0.2297]])

In [39]:
##Permutation of tensors
# torch.einsum('ij->ji', x)

## Summation of tensors
# torch.einsum('ij->', x)

## Column sum
# torch.einsum('ij->j', x)

## Row Sum
# torch.einsum('ij->i', x) 

## Matrix vector Multiplication
v = torch.rand((1,3))
# torch.einsum("ij,kj -> ik", x, v)

## Matrix Matrix Multiplication
torch.einsum("ij,kj -> ik", x, x)

tensor([[0.6579, 0.5652],
        [0.5652, 0.9758]])

In [38]:
x @ x.T

tensor([[0.6579, 0.5652],
        [0.5652, 0.9758]])