# Tensor manipulation using einops

In this notebook we will use einops to rewrite typical deep learning operations in a more clear and concise way. First, let's intall einops:

In [1]:
!pip install einops

import torch
import einops
from einops import rearrange, reduce, einsum

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


## Flattening a tensor

This operation is used typically before fully connected layers.

The size of input tensor will be b c h w (batch of images). And the size ouput should be b (c h w)

In [3]:
x = torch.randn(4,3,6,7)

# write the einops operation here
y = rearrange(x, 'b c h w -> b (c h w)')

print(y.shape)

torch.Size([4, 126])


## Pooling

Pooling is typically used to reduce the spatial size of a tensor in convolutional networks.

Rewrite pooling2d using einops:

In [8]:
x = torch.randn(4,3,10,10)

import  torch.nn.functional as F

y = F.avg_pool2d(x,(2,2),2)
print(y.shape)

#repeat average pooling using einops
y2 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'mean', h2=2, w2=2)

print("the result of difference should be close to zero:")
print((y-y2).abs().max())


torch.Size([4, 3, 5, 5])
the result of difference should be close to zero:
tensor(1.1921e-07)


## Patch embedding

Patch embedding is the first operation in vision transformers. The steps are:

1. Split image into patches
2. Flatten patches
3. Apply a linear transformation

In many implementations, patch embedding is done using a convolutional layer as follows:

In [11]:
from torch import nn
from torch import Tensor

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.projection =  nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)


    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        # reshape from b embed_size h//patch_size w//patch_size to b ( h//patch_size w//patch_size) embed_size
        x = x.view(x.shape[0], x.shape[1], -1) # these two lines can be done with einops too!
        x = x.transpose(-2,-1)

        return x

In [10]:
x = torch.randn(4,3,224,224)

pe = PatchEmbedding(3,16,768)

y = pe(x)

print(y.shape)


torch.Size([4, 196, 768])


In [16]:
from einops.layers.torch import Rearrange

class PatchEmbedding_einops(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        #extract and flatten patches using Rearrange Layer
        self.rearrange = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
        self.projection =  nn.Linear(in_channels * patch_size * patch_size, emb_size)


    def forward(self, x: Tensor) -> Tensor:
        x = self.rearrange(x)
        x = self.projection(x)

        return x

In [17]:
x = torch.randn(4,3,224,224)

pe_einops = PatchEmbedding_einops(3,16,768)

y = pe_einops(x)

print(y.shape)

torch.Size([4, 196, 768])


# Self attention

Lets implement a self attention block using basic torch operations


In [None]:
class SelfAttention(nn.Module):
    def __init__(self, input_dim: int = 768, dropout: float = 0):
        super().__init__()
        self.input_dim = input_dim
        # fuse the queries, keys and values in one matrix (more efficient)
        self.qkv = nn.Linear(input_dim, input_dim * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(input_dim, input_dim)

    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        x_qkv =self.qkv(x)
        # split keys, queries and values from x_qkv
        queries, keys, values = torch.chunk(x_qkv,3, dim=-1)


        scaling = self.input_dim ** (1/2)
        energy = torch.bmm(queries, keys.transpose(-1,-2)) #bactched matrix multiplication
        if mask is not None:
          fill_value = torch.finfo(torch.float32).min
          energy.mask_fill(~mask, fill_value)

        att = F.softmax(energy, dim = -1) / scaling
        att = self.att_drop(att)

        out = torch.bmm(att, values)  #bactched matrix multiplication

        return out


x = torch.randn(5,100,768)
sa = SelfAttention(768,0.0)

y = sa(x)
print(y.shape)




torch.Size([5, 100, 768])


Now rewrite SelfAttention using einops

In [21]:

class SelfAttention_einops(nn.Module):
    def __init__(self, input_dim: int = 768, dropout: float = 0):
        super().__init__()
        self.input_dim = input_dim
        # fuse the queries, keys and values in one matrix (more efficient)
        self.qkv = nn.Linear(input_dim, input_dim * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(input_dim, input_dim)

    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        x_qkv =self.qkv(x)
        # split keys, queries and values from x_qkv
        # Reshape to 3 b n embed_dim using einops
        x_qkv = rearrange(x_qkv, 'b n (qkv e) -> qkv b n e', qkv=3)

        queries, keys, values = x_qkv[0], x_qkv[1], x_qkv[2]

        print(queries.shape, keys.shape, values.shape)

        scaling = self.input_dim ** (1/2)
        # calculate energy using einsum
        energy = einsum('b q d, b d k -> b q k', queries, keys)
        if mask is not None:
          fill_value = torch.finfo(torch.float32).min
          energy.mask_fill(mask==1, fill_value)

        att = F.softmax(energy, dim = -1) / scaling
        att = self.att_drop(att)

        #multiply att with values using einops
        out = None
        return out


x = torch.randn(5,100,768)
sa = SelfAttention_einops(768,0.0)

y = sa(x)
print(y.shape)

torch.Size([5, 100, 768]) torch.Size([5, 100, 768]) torch.Size([5, 100, 768])


ValueError: The last argument passed to `einops.einsum` must be a string, representing the einsum pattern.

## Multihead attention

Now let's write multihead attention

In [None]:
import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, input_dim, num_heads, dropout=0.0):
        super(MultiHeadSelfAttention, self).__init__()
        assert input_dim % num_heads == 0, "Input dimension must be divisible by the number of heads"

        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads

        self.qkv = nn.Linear(input_dim, 3*input_dim)


        self.dropout = nn.Dropout(dropout)
        self.output_projection = nn.Linear(input_dim, input_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, input_dim = x.size()

        x_qkv = self.qkv(x) # only one linear layer
                            # x_qkv contains queries, keys values for all heads

        # split in querys, keys, values
        queries, keys, values = torch.chunk(x_qkv,3, dim=-1)

        # Reshape queries, keys, and values to split heads, the ouput dim is:  b num_heads seq_len head_dim
        queries = queries.view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        keys = keys.view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        values = values.view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Compute scaled dot-product attention
        scaling = self.input_dim ** (1/2)
        energy = torch.matmul(queries, keys.transpose(-2, -1)) / scaling
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)

        att = F.softmax(energy, dim=-1)
        att = self.dropout(att)

        # Apply attention to values
        out = torch.matmul(att, values)

        #importat detail: contigous is required before view (or use reshape!)
        out = out.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, input_dim)


        return out


x = torch.randn(5,100,768)
msa = MultiHeadSelfAttention(768,12)

y = msa(x)
print(y.shape)


torch.Size([5, 100, 768])


Now reimplement the previous class using einops

In [None]:
import torch
import torch.nn as nn

class MultiHeadSelfAttention_einops(nn.Module):
    def __init__(self, input_dim, num_heads, dropout=0.0):
        super(MultiHeadSelfAttention_einops, self).__init__()
        assert input_dim % num_heads == 0, "Input dimension must be divisible by the number of heads"

        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads

        self.qkv = nn.Linear(input_dim, 3*input_dim)


        self.dropout = nn.Dropout(dropout)
        self.output_projection = nn.Linear(input_dim, input_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, input_dim = x.size()

        x_qkv = self.qkv(x) # only one linear layer
                            # x_qkv contains queries, keys values for all heads

        # split in querys, keys, values using rearrange as before
        x_qkv =

        queries, keys, values = x_qkv[0], x_qkv[1], x_qkv[2]



        # Reshape queries, keys, and values to split heads, the ouput dim is:  b num_heads seq_len head_dim
        queries =
        keys =
        values =

        # Compute scaled dot-product attention usin einops
        scaling = self.input_dim ** (1/2)
        energy =
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)

        att = F.softmax(energy, dim=-1)
        att = self.dropout(att)

        # Apply attention to values usin einops
        out =

        # rearrange out to join heads
        out =



        return out


x = torch.randn(5,100,768)
msa = MultiHeadSelfAttention_einops(768,12)

y = msa(x)
print(y.shape)


torch.Size([5, 100, 768])


64.0