# install & docs

https://github.com/arogozhnikov/einops

`pip install einops`

einops 支持 **numpy, pytorch, tensorflow, keras, jax** 等

einops 所有操作的 pattern 中,axis 名之间必须用空格分隔;连写的字母(如 `ik`)会被当作一个名为 `ik` 的 axis,而不是 `i` 和 `k` 两个 axis。

```python
einops.einsum
    "i j, k j -> i k" ✅
    "i j, k j -> ik"  ❌

einops.rearrange
    "t (b c) -> b c t" ✅
    "t (b c) -> b ct"  ❌
```


In [28]:
# Numerical backends, the einops functional API, and its PyTorch layer wrappers.
import numpy as np
import torch
from einops import einsum, rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

# Seed both RNGs so the random operands below are reproducible; the generator
# object returned by torch.manual_seed is the value this cell displays.
np.random.seed(seed=0)
torch.manual_seed(seed=0)

<torch._C.Generator at 0x14c9d490c30>

# einsum: same as np.einsum / torch.einsum, but einops takes the tensors first and the pattern last

In [29]:
# Random operands for a batched matrix product: a is (1, 196, 768), b is (1, 768, 196).
a = np.random.random(size=(1, 196, 768))
b = np.random.random(size=(1, 768, 196))

In [30]:
# Wrap the numpy operands as torch tensors (torch.from_numpy shares the numpy buffer).
a_t, b_t = torch.from_numpy(a), torch.from_numpy(b)

## 不需要转置

In [39]:
# numpy einsum: pattern first, operands after; contracts the shared axis c.
c0 = np.einsum("b p c, b c k -> b p k", a, b)
np.shape(c0)

(1, 196, 196)

In [40]:
# einops.einsum: operands first, pattern last (the reverse of np.einsum).
c1 = einsum(a, b, "b p c, b c k -> b p k")
np.shape(c1)

(1, 196, 196)

In [41]:
# Same contraction on the same arrays, so an exact element-wise match is checked.
(c0 == c1).all()

True

In [42]:
# torch.einsum: pattern first, operands after; spaces in the pattern are optional here.
c0_t = torch.einsum("b p c, b c k -> b p k", a_t, b_t)
c0_t.size()

torch.Size([1, 196, 196])

In [44]:
# Same contraction through einops; einops requires spaces between axis names.
c1_t = einsum(a_t, b_t, "b p c, b c k -> b p k")
c1_t.size()

torch.Size([1, 196, 196])

In [45]:
# Element-wise exact comparison; displays a 0-d bool tensor.
(c0_t == c1_t).all()

tensor(True)

## 需要转置

In [51]:
# Pass b already transposed to (b, k, c) and adjust the pattern to match.
c2 = np.einsum("b p c, b k c -> b p k", a, np.swapaxes(b, 1, 2))
np.shape(c2)

(1, 196, 196)

In [52]:
# einops version of the transposed-operand contraction (operands first, pattern last).
c3 = einsum(a, np.swapaxes(b, 1, 2), "b p c, b k c -> b p k")
np.shape(c3)

(1, 196, 196)

In [53]:
# c2/c3 read b through a transposed (strided) view, so the backend may
# accumulate the products in a different order than for c0/c1. Exact `==`
# on floats is therefore fragile; compare with a floating-point tolerance.
np.allclose(c0, c1), np.allclose(c1, c2), np.allclose(c2, c3)

(True, True, True)

In [47]:
# Pass b_t already transposed to (b, k, c) and adjust the pattern to match.
c2_t = torch.einsum("b p c, b k c -> b p k", a_t, torch.transpose(b_t, 1, 2))
c2_t.size()

torch.Size([1, 196, 196])

In [48]:
# einops version of the transposed-operand contraction on torch tensors.
c3_t = einsum(a_t, torch.transpose(b_t, 1, 2), "b p c, b k c -> b p k")
c3_t.size()

torch.Size([1, 196, 196])

In [49]:
# c2_t/c3_t read b_t through a transposed view, so the backend may
# accumulate the products in a different order than for c0_t/c1_t. Exact
# `==` on floats is fragile; compare with a tolerance. Each entry is still a
# 0-d bool tensor, as before.
(
    torch.isclose(c0_t, c1_t).all(),
    torch.isclose(c1_t, c2_t).all(),
    torch.isclose(c2_t, c3_t).all(),
)

(tensor(True), tensor(True), tensor(True))

# rearrange: reshape, transpose, permute

In [9]:
input_tensor = np.ones((196, 10, 768))
# Pure axis permutation (t, b, c) -> (b, c, t); axis names must be space-separated.
np.shape(rearrange(input_tensor, "t b c -> b c t"))

(10, 768, 196)

In [10]:
input_tensor = torch.ones((196, 10, 768))
# Functional rearrange on a torch tensor: permute (t, b, c) -> (b, c, t).
output_tensor0 = rearrange(input_tensor, "t b c -> b c t")
output_tensor0.size()

torch.Size([10, 768, 196])

In [11]:
# Layer form of the same permutation: build the callable once from the
# pattern, then apply it to the tensor.
rearrange_ = Rearrange("t b c -> b c t")
output_tensor1 = rearrange_(input_tensor)
output_tensor1.size()

torch.Size([10, 768, 196])

In [12]:
# Functional and layer forms must give identical tensors.
(output_tensor0 == output_tensor1).all()

tensor(True)

In [13]:
input_tensor = torch.ones((196, 10*768))
# Parentheses denote a composed axis: the second input axis (size 10*768) is
# split into (b, c). Only enough sizes to infer the rest must be given:
# with b=10, c is inferred as 768.
output_tensor2 = rearrange(input_tensor, "t (b c) -> b c t", b=10)
output_tensor2.size()

torch.Size([10, 768, 196])

In [14]:
# Splitting the composed axis must reproduce the plain permutation result.
(output_tensor0 == output_tensor2).all()

tensor(True)

## 轴名(axis name)可以使用多个字母(以空格分隔)

In [15]:
input_tensor = np.ones((1, 196, 3*768))
# "qkv" is ONE axis name because axis names are split on spaces.
# The last axis (3*768) splits into (qkv=3, h=12, d), so d is inferred as 64.
output_tensor = rearrange(input_tensor, "b n (qkv h d) -> qkv b h n d", qkv=3, h=12)
np.shape(output_tensor)

(3, 1, 12, 196, 64)

In [16]:
input_tensor = torch.ones((1, 196, 3*768))
# Split the last axis (3*768) into (qkv=3, h=12, d); d is inferred as 64.
# "qkv" is a single axis name because axis names are separated by spaces.
output_tensor0 = rearrange(input_tensor, "b n (qkv h d) -> qkv b h n d", qkv=3, h=12)
output_tensor0.size()

torch.Size([3, 1, 12, 196, 64])

In [17]:
# Plain-torch equivalent of the rearrange above.
output_tensor1 = (
    input_tensor
    .reshape(1, 196, 3, 12, -1)   # (b, n, qkv, h, d), d inferred
    .permute(2, 0, 3, 1, 4)       # -> (qkv, b, h, n, d)
)
output_tensor1.size()

torch.Size([3, 1, 12, 196, 64])

In [18]:
# rearrange and the manual reshape/permute must agree exactly.
(output_tensor0 == output_tensor1).all()

tensor(True)

# reduce: rearrange + min/max/sum/mean/prod

In [19]:
input_tensor = np.ones((1, 3, 10, 10))
# One call does 2x2 mean pooling (h2, w2 are averaged away) and moves
# channels last: (b, c, 10, 10) -> (b, 5, 5, c).
np.shape(reduce(input_tensor, "b c (h h2) (w w2) -> b h w c", "mean", h2=2, w2=2))

(1, 5, 5, 3)

In [20]:
input_tensor = torch.ones((1, 3, 10, 10))
# 2x2 mean pooling plus channels-last permutation on a torch tensor.
output_tensor0 = reduce(input_tensor, "b c (h h2) (w w2) -> b h w c", "mean", h2=2, w2=2)
output_tensor0.size()

torch.Size([1, 5, 5, 3])

In [21]:
# combine rearrangement and reduction
# Layer form of the same pooling: configure once, then call on the tensor.
reduce_ = Reduce("b c (h h2) (w w2) -> b h w c", "mean", h2=2, w2=2)
output_tensor1 = reduce_(input_tensor)
output_tensor1.size()

torch.Size([1, 5, 5, 3])

In [22]:
# Functional and layer reductions must give identical tensors.
(output_tensor0 == output_tensor1).all()

tensor(True)

# repeat: repeat, expand

In [23]:
input_tensor = np.ones((4, 4))
# Add a new trailing axis c of size 3 by copying the (4, 4) input along it.
output_tensor = repeat(input_tensor, "h w -> h w c", c=3)
np.shape(output_tensor)

(4, 4, 3)

In [24]:
input_tensor = torch.ones((4, 4))
# Same new-axis copy on a torch tensor: (4, 4) -> (4, 4, 3).
output_tensor = repeat(input_tensor, "h w -> h w c", c=3)
output_tensor.size()

torch.Size([4, 4, 3])