In [7]:
import torch
from einops import rearrange, einsum


D = torch.randn(10, 128)
A = torch.randn(10, 128)

## Basic implementation
# Hard to tell the input and output shapes and what they mean.
# What shapes can D and A have, and do any of these have unexpected behavior?
Y1 = D @ A.T
print("Y1.shape", Y1.shape)

## Einsum is self-documenting and robust
# DA->Y
Y2 = einsum(D, A, "batch sequence d_in, d_out d_in -> batch sequence d_out")
print("Y2.shape", Y2.shape)
BD = torch.randn(1024, 10, 100)

## Or, a batched version where D can have any leading dimensions but A is constrained.
Y3 = einsum(BD, A, "... d_in, d_out d_in -> ... d_out")
print("Y3.shape", Y3.shape)

torch.Size([10, 10])


RuntimeError: einsum(): the number of subscripts in the equation (3) does not match the number of dimensions (2) for operand 0 and no ellipsis was given

In [None]:
    #longest_token.decode('utf-8', errors='rimages = torch.randn(64, 128, 128, 3) # (batch, height, width, channel)
dim_by = torch.linspace(start=0.0, end=1.0, steps=10)
## Reshape and multiply
dim_value = rearrange(dim_by, "dim_value-> 1 dim_value 1 1 1")
images_rearr = rearrange(images, "b height width channel -> b 1 height width channel")
dimmed_images = images_rearr * dim_value

## Or in one go:
dimmed_images = einsum(images, dim_by, "batch height width channel, dim_value -> batch dim_value height width channel")