In [73]:
import torch
from einops import rearrange, reduce, repeat
import torch.nn.functional as F
import numpy as np

<em>ops(x:Tensor, 'old -> new')</em>

Re-arranging elements according to a pattern

In [3]:
x = torch.randn(2, 3, 4)  # (batch, height, width)
print(x[0], end="\n\n")   # printing first batch 

y = rearrange(x, 'b h w -> b w h')  # Swap height and width
print(y[0], end="\n\n")

print("Before:", x.shape)
print("After:", y.shape)  # (2, 4, 3)

tensor([[-0.8691, -0.6780,  0.1807,  1.5269],
        [-2.2843,  0.0381,  0.0679, -0.2106],
        [ 0.6618,  3.1145, -0.4201, -0.2379]])

tensor([[-0.8691, -2.2843,  0.6618],
        [-0.6780,  0.0381,  3.1145],
        [ 0.1807,  0.0679, -0.4201],
        [ 1.5269, -0.2106, -0.2379]])

Before: torch.Size([2, 3, 4])
After: torch.Size([2, 4, 3])


For a larger tensor with `(n, h, w, p, q, c)` dimensions.

* n: batch size - Number of samples in the batch (sometimes b)
* h: height - Vertical spatial dimension (e.g., image height or feature map rows)
* w: width - Horizontal spatial dimension (e.g., image width or feature map columns)
* p: patch height or grid row - Often used for patch size or subdivisions of `h`
* q: patch width or grid column - Often used for patch size or subdivisions of w
* c:: channels - Number of channels (e.g., RGB = 3, or feature channels in CNNs)

For example: (8, 4, 4, 4, 4, 3) → batch of 8 images, split into 4x4 grids of 4x4 patches, with 3 channels

In [118]:
x = np.random.RandomState(42).normal(size=[10, 32, 100, 200])

x = torch.from_numpy(x)
x.requires_grad = True

x.shape # b c h w

torch.Size([10, 32, 100, 200])

For `tensor.permute` dimension indices <br><br>
Index - Dimension <br>
0 - b (batch size) <br>
1 - c (channels) <br>
2 - h (height) <br>
3 - w (width) <br>

In [58]:
# converting bchw to bhwc format and back is a common operation in CV
y_einops = rearrange(x, 'b c h w -> b h w c') # Using einops.rearrange
y_torch = x.permute(0, 2, 3, 1) # Using PyTorch's permute

print(y_einops.shape, y_torch.shape) 
assert torch.equal(y_einops, y_torch), "Not identical operation"

torch.Size([10, 100, 200, 32]) torch.Size([10, 100, 200, 32])


In [59]:
y0 = x                                    # torch.Size([10, 32, 100, 200])
y1 = reduce(y0, "b c h w -> b c", "max")  # torch.Size([10, 32]) - Apply global max pooling over spatial dimensions.
y2 = rearrange(y1, "b c -> c b")          # torch.Size([32, 10])
y3 = reduce(y2, "c b -> ", "sum")         # torch.Size([]) - Sums all values to produce a single scalar.
print(y3)

y3.backward()
print(reduce(x.grad, "b c h w -> ", "sum"))

tensor(1285.4242, dtype=torch.float64, grad_fn=<SumBackward1>)
tensor(320., dtype=torch.float64)


<b>Flattening</b> is common operation, frequently appears at the boundary between convolutional layers and fully connected layers

In [67]:
# 4D tensor [b, c, h, w] to a 2D tensor [b, c*h*w]
y_einops = rearrange(x, "b c h w -> b (c h w)")

# view() reshapes a tensor without copying memory (efficient).
# x.size(0) gives the batch size b.
# -1 tells PyTorch to infer the correct size for the second dimension (i.e., c * h * w).
y_torch0 = x.view(x.size(0), -1) 

# .reshape() is similar to .view(), but automatically makes a copy if needed.
# Can be slightly slower than .view() if it ends up copying memory.
y_torch1 = x.reshape(x.shape[0], -1)

# Flattens all dimensions starting from start_dim=1 (i.e., flattens [c, h, w]).
# Very clean and readable for flattening "everything after the batch".
y_torch2 = x.flatten(start_dim=1)

print(y_einops.shape, y_torch0.shape, y_torch1.shape, y_torch2.shape) 
assert torch.equal(y_einops, y_torch0), "Not identical operation"
assert torch.equal(y_einops, y_torch1), "Not identical operation"
assert torch.equal(y_einops, y_torch2), "Not identical operation"

torch.Size([10, 640000]) torch.Size([10, 640000]) torch.Size([10, 640000]) torch.Size([10, 640000])


<b>space-to-depth</b>

In [97]:
# Rearranges spatial/temporal blocks into the channel dimension.
# Used for downsampling without losing information (e.g. in real-time detection).

y_einpos = rearrange(x, "b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=2, w1=2)
print(y_einops.shape)

torch.Size([1, 3, 4, 4, 4])


<b>depth-to-space</b> (notice that it's reverse of the previous)

In [98]:
y_einpos = rearrange(x, "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=2, w1=2)
print(y_einops.shape)

torch.Size([1, 3, 4, 4, 4])


Simple <b>global average pooling</b>

In [71]:
# Averages each channel over the spatial dimensions (h, w).
# Equivalent to Global Average Pooling (GAP), common in CNN classification heads
y_einops = reduce(x, "b c h w -> b c", reduction="mean")

y_torch0 = x.mean((2,3)) # reduce over height and width

print(y_einops.shape, y_torch0.shape)
assert torch.equal(y_einops, y_torch0), "Not identical operation"

torch.Size([10, 32]) torch.Size([10, 32])


<b>max-pooling</b> with a kernel 2x2 (2D) - reduce over spatial (H×W)	

In [84]:
# Treats (h h1) and (w w1) as 2×2 blocks and takes max over each block.
# Equivalent to 2D MaxPooling with kernel_size=2 and stride=2.
y_einops = reduce(x,  "b c (h h1) (w w1) -> b c h w", reduction="max", h1=2, w1=2)
# same as the above, a 2×2 max pooling, with shorthand for specifying factor sizes directly.
# y_einops = reduce(x, "b c (h 2) (w 2) -> b c h w", reduction="max") 

y_torch0 = F.max_pool2d(x, kernel_size=2, stride=2) # 2d max pooling

print(y_einops.shape, y_torch0.shape)
assert torch.equal(y_einops, y_torch0), "Not identical operation"

torch.Size([10, 32, 50, 100]) torch.Size([10, 32, 50, 100])


<b>Temporal Max-Pooling (Batched Sequences)</b>

In [91]:
# for sequential 1-d models, you'll probably want pooling over time
# it applies temporal max-pooling over time dimension.
# it also assumes time has been reshaped as (t*2) — combines pairs of time steps and reduces to max.

# einops example
# reduce(x, '(t 2) b c -> t b c', reduction='max')

# pytorch equivalent
# t2, b, c = x.shape
# x = x.view(t2 // 2, 2, b, c)  # reshape to [t, 2, b, c]
# y = x.max(dim=1).values      # max over the 2-timestep window

# Temporal Max-Pooling (Batched Sequences) are used in sequence models (audio, video, nlp) by applying max operation 
# over time intervals (e.g. every 2 time steps)
# useful for downsampling temporal resolution (reducing length while keeping important features)

## using different x input example than the global one
x_t = torch.randn(8, 2, 4)  # (t*2, b, c) → (4*2, 2, 4)
y_t_einops = reduce(x_t, '(t 2) b c -> t b c', reduction='max')

x_reshaped = x_t.view(4, 2, 2, 4)  # [t, 2, b, c]
y_t_torch = x_reshaped.max(dim=1).values  # max over 2-timestep window

print(x_t.shape, y_t_einops.shape, y_t_torch.shape)
assert torch.equal(y_t_einops, y_t_torch), "Not identical operation"

torch.Size([8, 2, 4]) torch.Size([4, 2, 4]) torch.Size([4, 2, 4])


**3D Max-Pooling** - Reduce over volumetric blocks	

In [95]:
# for volumetric models, all three dimensions are pooled
# 3D max-pooling over non-overlapping 2x2x2 cubes (depth, height, width).
# Common in 3D vision (e.g., volumetric data, medical imaging).

# einops example
# reduce(x, 'b c (x 2) (y 2) (z 2) -> b c x y z', reduction='max')

# pytorch example
# y = F.max_pool3d(x, kernel_size=2, stride=2)

x_3d = torch.randn(1, 3, 8, 8, 8) # [B, C, D, H, W]
y_3d_einops = reduce(x_3d, 'b c (d 2) (h 2) (w 2) -> b c d h w', reduction='max')
y_3d_torch  = F.max_pool3d(x_3d, kernel_size=2, stride=2) 

print(x_3d.shape, y_3d_einops.shape, y_3d_torch.shape)
assert torch.equal(y_3d_einops, y_3d_torch), "Not identical operation"

torch.Size([1, 3, 8, 8, 8]) torch.Size([1, 3, 4, 4, 4]) torch.Size([1, 3, 4, 4, 4])


**Squeeze and unsqueeze (expand_dims)**

In [103]:
# models typically work only with batches,
# so to predict a single image ...

# took 1 image from the batch and re-arranged the axis
image = rearrange(x[0, :3], "c h w -> h w c") 
print(image.shape)

# ... create a dummy 1-element axis ...
y_einops = rearrange(image, "h w c -> () c h w")
print(y_einops.shape)

# ... imagine you predicted this with a convolutional network for classification,
# we'll just flatten axes ...
predictions = rearrange(y, "b c h w -> b (c h w)")
print(predictions.shape)

# ... finally, decompose (remove) dummy axis
predictions = rearrange(predictions, "() classes -> classes")
print(predictions.shape)

torch.Size([100, 200, 3])
torch.Size([1, 3, 100, 200])
torch.Size([1, 60000])
torch.Size([60000])


**keepdims-like behavior for reductions** <br>
* empty composition () provides dimensions of length 1, which are broadcastable. <br>
* alternatively, you can use just 1 to introduce new axis, that's a synonym to ()

In [107]:
print(x.shape, reduce(x, "b c h w -> b c 1 1", "mean").shape)
assert torch.equal(reduce(x, "b c h w -> b c 1 1", "mean"), reduce(x, "b c h w -> b c () ()", "mean")), "Not identical operation"

torch.Size([10, 32, 100, 200]) torch.Size([10, 32, 1, 1])


In [109]:
# per-channel mean-normalization for each image:
y_einops = x - reduce(x, "b c h w -> b c 1 1", "mean")
print(y_einops.shape)

# per-channel mean-normalization for whole batch:
y_einops = x - reduce(x, "b c h w -> 1 c 1 1", "mean")
print(y_einops.shape)

torch.Size([10, 32, 100, 200])
torch.Size([10, 32, 100, 200])


**Stacking**

In [120]:
print(x.shape)
list_of_tensors = list(x)

torch.Size([10, 32, 100, 200])


In [112]:
# New axis (one that enumerates tensors) appears first on the left side of expression. Just as if we were indexing list 
#first we'd get tensor by index
tensors = rearrange(list_of_tensors, "b c h w -> b h w c")
print(tensors.shape)

torch.Size([10, 100, 200, 32])


In [113]:
# or maybe stack along last dimension?
tensors = rearrange(list_of_tensors, "b c h w -> h w c b")
print(tensors.shape)

torch.Size([100, 200, 32, 10])


**Concatenation**

In [115]:
# concatenate over the first dimension?
tensors = rearrange(list_of_tensors, "b c h w -> (b h) w c")
print(tensors.shape)

torch.Size([1000, 200, 32])


In [116]:
# or maybe concatenate along last dimension?
tensors = rearrange(list_of_tensors, "b c h w -> h w (b c)")
print(tensors.shape)

torch.Size([100, 200, 320])


**Shuffling within a dimension**

In [124]:
# channel shuffle (as it is drawn in shufflenet paper)
y_einops = rearrange(x, "b (g1 g2 c) h w-> b (g2 g1 c) h w", g1=4, g2=4)
print(y_einops.shape)

# simpler version of channel shuffle
y_einops = rearrange(x, "b (g c) h w-> b (c g) h w", g=4)
print(y_einops.shape)

torch.Size([10, 32, 100, 200])
torch.Size([10, 32, 100, 200])
