In [11]:
import einx
import torch
import numpy as np
from torch import nn
from jaxtyping import Float
from einops import rearrange, einsum, reduce


# Two all-ones operands; the annotations name each axis (both sequence axes have length 3 here).
x: Float[torch.Tensor, "batch seq1 hidden"] = torch.ones((2, 3, 4))
y: Float[torch.Tensor, "batch seq2 hidden"] = torch.ones((2, 3, 4))
print(x, y, sep="\n")

## Basic implementation: batched matmul of x with the last two axes of y swapped.
y_swapped = y.transpose(-2, -1)  # (batch, hidden, seq2)
z = torch.matmul(x, y_swapped)   # (batch, seq1, seq2)
print(z)
## The bare expression does not say what shapes the operands have or what each axis means.
## What shapes can x and y have, and do any of these have unexpected behavior?

# All-ones operands with named axes.
x: Float[torch.Tensor, "batch seq1 hidden"] = torch.ones((2, 3, 4))
y: Float[torch.Tensor, "batch seq2 hidden"] = torch.ones((2, 3, 4))

# Einsum is self-documenting and robust: each pattern reads "x axes, y axes -> z axes".
patterns = (
    # Every axis of x, y and z is named explicitly.
    "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2",
    # Batched version: `...` stands in for whatever leading dimensions x and y have,
    # while the trailing seq/hidden axes stay named.
    "... seq1 hidden, ... seq2 hidden -> ... seq1 seq2",
)
for pattern in patterns:
    z = einsum(x, y, pattern)
    print(z)


tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
tensor([[[4., 4., 4.],
         [4., 4., 4.],
         [4., 4., 4.]],

        [[4., 4., 4.],
         [4., 4., 4.],
         [4., 4., 4.]]])
tensor([[[4., 4., 4.],
         [4., 4., 4.],
         [4., 4., 4.]],

        [[4., 4., 4.],
         [4., 4., 4.],
         [4., 4., 4.]]])
tensor([[[4., 4., 4.],
         [4., 4., 4.],
         [4., 4., 4.]],

        [[4., 4., 4.],
         [4., 4., 4.],
         [4., 4., 4.]]])


In [14]:
images = torch.randn((64, 128, 128, 3))      # (batch, height, width, channel)
dim_by = torch.linspace(0.0, 1.0, steps=10)  # 10 evenly spaced scale factors from 0.0 to 1.0

## Reshape and multiply
# Surround the factor axis with singleton axes -> (1, 10, 1, 1, 1).
dim_value = rearrange(dim_by, "dim_value           -> 1 dim_value 1 1 1")
# Insert a singleton axis right after batch -> (64, 1, 128, 128, 3), i.e. (B, 1, H, W, C).
images_rearr = rearrange(images, "b height width channel -> b 1 height width channel")
# Broadcasting pairs every image with every factor -> (64, 10, 128, 128, 3).
dimmed_images = torch.mul(images_rearr, dim_value)

## or in one go
# NOTE(review): the einsum form below sits inside a bare string literal, so it is never
# executed; as the last expression of the cell, the string itself becomes the cell output.
"""
dimmed_images = einsum(images, dim_by, "batch height width channel, dim_value -> batch dim_value height width channel")
"""

'\ndimmed_images = einsum(images, dim_by, "batch height width channel, dim_value -> batch dim_value height width channel")\n'

In [15]:
channels_last = torch.randn((64, 32, 32, 3))  # (batch, height, width, channel)
B = torch.randn((32 * 32, 32 * 32))           # square matrix applied along the flattened pixel axis
# old way
## mix across all pixels of an image tensor using plain view / transpose calls
n_rows, n_cols, n_channels = channels_last.shape[1:]
# Merge height and width into one pixel axis, then move it last so matmul contracts over it.
channels_first_flat = channels_last.view(-1, n_rows * n_cols, n_channels).transpose(1, 2)
# Contract the pixel axis with the second axis of B (hence the transpose of B).
channels_first_flat_transformed = torch.matmul(channels_first_flat, B.transpose(0, 1))
# Put the pixel axis back in front of channel, then split it into height and width again.
channels_last_transformed = channels_first_flat_transformed.transpose(1, 2).view(*channels_last.shape)




In [17]:
channels_last = torch.randn((64, 32, 32, 3))  # (batch, height, width, channel)
B = torch.randn((32 * 32, 32 * 32))
height, width = 32, 32
image_size = {"height": height, "width": width}

## Rearrange replaces clunky torch view + transpose
# Merge the two spatial axes into a single trailing pixel axis.
to_channels_first = "batch height width channel -> batch channel (height width)"
channels_first = rearrange(channels_last, to_channels_first)

# Contract the incoming pixel axis with the second axis of B.
mix_pixels = "batch channel pixel_in, pixel_out pixel_in -> batch channel pixel_out"
channels_first_transformed = einsum(channels_first, B, mix_pixels)

# Split the pixel axis back into height x width and move channel last again.
to_channels_last = "batch channel (height width) -> batch height width channel"
channels_last_transformed = rearrange(
    channels_first_transformed, to_channels_last, **image_size
)

In [None]:
height, width = 32, 32

# channels_last and B are not defined in this cell; they must already exist in the kernel.
# The two adjacent string literals join into a single pattern: the flattened axes of B,
# (row_out col_out) and (row_in col_in), are written as compositions of row/column axes,
# and the col_in / col_out sizes are supplied as keyword arguments.
pixel_mixing = (
    "batch row_in col_in channel, (row_out col_out) (row_in col_in)"
    "-> batch row_out col_out channel"
)
axis_sizes = dict(col_in=width, col_out=width)
channels_last_transformed = einx.dot(pixel_mixing, channels_last, B, **axis_sizes)