In [None]:
import torch
from einops import rearrange, einsum 

In [7]:
D = torch.randn(2, 3, 4)
A = torch.randn(3, 4)

In [11]:
Y = D @ A.T

In [9]:
print(Y.shape)

torch.Size([2, 3, 3])


In [12]:
Y1 = einsum(D, A, "batch seqence d_in, d_out d_in -> batch seqence d_out")

In [13]:
Y1 == Y

tensor([[[True, True, True],
         [True, True, True],
         [True, True, True]],

        [[True, True, True],
         [True, True, True],
         [True, True, True]]])

In [14]:
Y2 = einsum(D, A, "... d_in, d_out d_in -> ... d_out")

In [15]:
images = torch.randn(64, 128, 128, 3)  # (batch, height, width, channels)
dim_by = torch.linspace(start=0.0, end=1.0, steps=10)  # (10,)

In [None]:
dim_values = rearrange(dim_by, "dim_value      -> 1 dim_value 1 1 1")  # 广播

In [20]:
dim_values.shape

torch.Size([1, 10, 1, 1, 1])

In [22]:
images_rearr = rearrange(images, "b height width channel -> b 1 height width channel")

In [23]:
images_rearr.shape

torch.Size([64, 1, 128, 128, 3])

In [24]:
dimmed_images = images_rearr * dim_values

In [25]:
dimmed_images.shape

torch.Size([64, 10, 128, 128, 3])

In [30]:
dimmed_images2 = einsum(
    images, dim_by,
    "batch height width channel, dim_value -> batch dim_value height width channel"
)

In [None]:
dimmed_images == dimmed_images2

In [33]:
channels_last = torch.randn(64, 32, 32, 3)
B = torch.randn(32 * 32, 32 * 32)

In [34]:
channels_last_flat = channels_last.view(
    -1, channels_last.size(1) * channels_last.size(2), channels_last.size(3)
)

In [35]:
channels_last_flat.shape

torch.Size([64, 1024, 3])

In [36]:
channels_first_flat = channels_last_flat.transpose(1, 2)
channels_first_flat.shape

torch.Size([64, 3, 1024])

In [37]:
channels_first_flat_transformed = channels_first_flat @ B.T
channels_first_flat_transformed.shape

torch.Size([64, 3, 1024])

In [38]:
channel_last_flat_transformed = channels_first_flat_transformed.transpose(1, 2)
channel_last_flat_transformed.shape

torch.Size([64, 1024, 3])

In [39]:
channels_last_transformed = channel_last_flat_transformed.view(*channels_last.shape)
channels_last_transformed.shape

torch.Size([64, 32, 32, 3])

In [40]:
height = width = 32

In [42]:
channels_first = rearrange(channels_last, "b h w c -> b c (h w)")
channels_first.shape

torch.Size([64, 3, 1024])

In [44]:
channels_first_transformed = einsum(
    channels_first, B,
    "b c pin, p_out p_in -> b c p_out"
)
channels_first_transformed.shape

torch.Size([64, 3, 1024])

In [46]:
channels_last_transformed2 = rearrange(
    channels_first_transformed,
    "b c (h w) -> b h w c",
    h=height,w=width
)

In [38]:
from einops import rearrange, reduce, asnumpy, einsum
import numpy as np

In [50]:
x = np.random.RandomState(42).normal(size=[10, 32, 100, 200])
x.shape

(10, 32, 100, 200)

In [2]:
import torch

In [53]:
x = torch.from_numpy(x)
x.requires_grad = True

In [54]:
type(x), x.shape

(torch.Tensor, torch.Size([10, 32, 100, 200]))

In [55]:
y = rearrange(x, "b c h w -> b h w c")
y.shape

torch.Size([10, 100, 200, 32])

In [60]:
y0 = x
y1 = reduce(y0, "b c h w -> b c", "max")
y2 = rearrange(y1, "b c -> c b")
y3 = reduce(y2, "c b -> ", "sum")
y3.backward()
print(reduce(x.grad, "b c h w -> ", "sum"))

tensor(960., dtype=torch.float64)


In [62]:
y3_numpy = asnumpy(y3)
print(type(y3_numpy))

<class 'numpy.ndarray'>


In [63]:
y = rearrange(x, "b c h w -> b (c h w)")
y.shape

torch.Size([10, 640000])

In [64]:
y = rearrange(x, "b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=2, w1=2)
y.shape

torch.Size([10, 128, 50, 100])

In [65]:
y = rearrange(x, "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=2, w1=2)
y.shape

torch.Size([10, 8, 200, 400])

In [27]:
x = torch.ones((2,2))
y = torch.empty(2)

In [28]:
x_mean = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True))

In [29]:
x_mean.shape

torch.Size([2, 1])

In [34]:
x_mean

tensor([[1.],
        [1.]])

In [31]:
x_mean1 = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=False))

In [32]:
x_mean1.shape

torch.Size([2])

In [33]:
x_mean1

tensor([1., 1.])

In [52]:
a = torch.ones(2,3,4)
b = torch.ones(5,4)

In [64]:
x1 = einsum(a, b, "b s d_in, d_out d_in -> b s d_out")
x1

tensor([[[4., 4., 4., 4., 4.],
         [4., 4., 4., 4., 4.],
         [4., 4., 4., 4., 4.]],

        [[4., 4., 4., 4., 4.],
         [4., 4., 4., 4., 4.],
         [4., 4., 4., 4., 4.]]])

In [None]:
x2 = einsum(b, a, "d_out d_in1, b s d_in -> b s d_out")
x2

In [65]:
x2.shape

torch.Size([2, 3, 5])

In [56]:
x2 == x1

tensor([[[False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False]],

        [[False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False]]])

In [57]:
x3 = a @ b.T

In [59]:
x1 == x3

tensor([[[True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True]],

        [[True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True]]])

In [66]:
import torch
from einops import rearrange, einsum

In [75]:
d_k = 8
theta = 0.1
seq = 8

In [76]:
torch.arange(0, d_k, 2).to(dtype=torch.float32)

tensor([0., 2., 4., 6.])

In [77]:
freq_arrange = 1 / (theta ** (torch.arange(0, d_k, 2).to(dtype=torch.float32) / d_k))

In [78]:
freq_arrange

tensor([1.0000, 1.7783, 3.1623, 5.6234])

In [80]:
token_positions = torch.arange(seq)
token_positions

tensor([0, 1, 2, 3, 4, 5, 6, 7])

In [87]:
theta = einsum(token_positions, freq_arrange, "s, dk_half -> s dk_half")
theta

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.0000,  1.7783,  3.1623,  5.6234],
        [ 2.0000,  3.5566,  6.3246, 11.2468],
        [ 3.0000,  5.3348,  9.4868, 16.8702],
        [ 4.0000,  7.1131, 12.6491, 22.4937],
        [ 5.0000,  8.8914, 15.8114, 28.1171],
        [ 6.0000, 10.6697, 18.9737, 33.7405],
        [ 7.0000, 12.4480, 22.1359, 39.3639]])

In [99]:
cos = theta.cos().repeat_interleave(2, dim=-1)
sin = theta.sin().repeat_interleave(2, dim=-1)

In [90]:
import math
math.cos(1.7783)

-0.20601776807921493

In [106]:
x = torch.randn(2, seq, d_k)
x.shape

torch.Size([2, 8, 8])

In [107]:
leading_dims = (1,) * (x.dim() - 2)
leading_dims

(1,)

In [110]:
cos = cos.view(*leading_dims, seq, d_k)
sin = sin.view(*leading_dims, seq, d_k)

In [112]:
r_x = rearrange(x, "... (s r) -> ... s r ", r=2)
r_x.shape

torch.Size([2, 8, 4, 2])

In [114]:
x_even, x_odd = r_x.unbind(dim=-1)

In [None]:
r_x = torch.stack((-x_odd, x_even), dim=-1)  # 旋转90度

In [117]:
r_x = rearrange(r_x, "... s r -> ... (s r)")

In [119]:
x

tensor([[[ 1.8282,  0.8732, -0.3041, -1.8611, -1.7955,  0.4487,  0.8303,
           0.4807],
         [-1.0876,  2.3761, -0.8377, -0.1925,  0.2339, -1.3503, -1.5107,
           0.1850],
         [-0.4693,  1.5095, -0.1448,  0.9863,  2.1509, -0.1755,  1.9013,
           0.2735],
         [ 0.8429, -0.4080, -0.2603, -0.5663, -0.9198, -0.8080,  2.0991,
           0.5454],
         [ 1.0636,  1.1429,  0.2439, -1.8725,  0.7116, -1.8701, -1.5315,
          -0.2919],
         [ 0.3703,  0.4410,  0.0990, -0.6092, -1.1373, -0.3860,  0.7922,
          -1.6737],
         [-1.2217, -0.3968,  0.2971,  0.0535,  0.1991, -1.2385, -0.2213,
          -0.9188],
         [ 1.0262,  0.6181,  1.3968,  0.2988,  0.2213, -0.0260,  1.1678,
           0.5342]],

        [[-1.6999, -0.4581, -0.7038,  1.1136, -1.4940,  0.8372, -0.6688,
           0.8712],
         [ 0.8199, -1.2803, -1.9211, -1.2290, -1.5066, -0.4937,  1.1617,
          -0.8774],
         [-0.3730, -0.7963,  0.1743, -0.6118, -1.0171,  0.2023, -2.0

In [121]:
r_x.shape

torch.Size([2, 8, 8])

In [122]:
x.shape

torch.Size([2, 8, 8])

In [123]:
result = einsum(r_x, x, "a d_x d_y, b d_y d_x -> a b")

In [124]:
result

tensor([[ 0.5235, -2.5713],
        [-1.5686,  1.2910]])

In [20]:
import torch

In [21]:
x = torch.tensor([[1, 2], [3, 4]])
x

tensor([[1, 2],
        [3, 4]])

In [22]:
y = torch.max(x, dim=1, keepdim=False)
y

torch.return_types.max(
values=tensor([2, 4]),
indices=tensor([1, 1]))

In [23]:
x = x - torch.max(x, dim=1, keepdim=True).values
x

tensor([[-1,  0],
        [-1,  0]])

In [24]:
x = torch.exp(x)
x

tensor([[0.3679, 1.0000],
        [0.3679, 1.0000]])

In [25]:
x / torch.sum(x, dim=1, keepdim=True)

tensor([[0.2689, 0.7311],
        [0.2689, 0.7311]])