In [None]:
!pip install -q einops

[?25l[K     |███████▉                        | 10 kB 18.5 MB/s eta 0:00:01[K     |███████████████▊                | 20 kB 7.1 MB/s eta 0:00:01[K     |███████████████████████▋        | 30 kB 3.5 MB/s eta 0:00:01[K     |███████████████████████████████▌| 40 kB 3.9 MB/s eta 0:00:01[K     |████████████████████████████████| 41 kB 122 kB/s 
[?25h

In [None]:
import einops
import torch
import numpy as np

import torch.nn as nn

In [None]:
# Random 5-D tensor used as the running example for the reshapes below.
x = torch.randn(1, 3, 85, 13, 13)

x.shape

torch.Size([1, 3, 85, 13, 13])

In [None]:
# I would like the shape to be
# (1, 3, 13, 13, 85)

In [None]:
# Plain PyTorch: move axis 2 (size 85) to the end by listing the new axis order.
y = x.permute(0, 1, 3, 4, 2)
y.shape

torch.Size([1, 3, 13, 13, 85])

In [None]:
# Same reordering with einops: every axis gets a name, and the pattern spells
# out the input layout on the left and the desired layout on the right.
y1 = einops.rearrange(x, "b num_anchors p h w -> b num_anchors h w p")
y1.shape

torch.Size([1, 3, 13, 13, 85])

In [None]:
# I would like the shape to be
# (1, 3 * 13 * 13, 85)

In [None]:
# Parentheses on the right-hand side merge num_anchors, h and w into a single
# axis of length 3 * 13 * 13 = 507.
y2 = einops.rearrange(x, "b num_anchors p h w -> b (num_anchors h w) p")
y2.shape

torch.Size([1, 507, 85])

In [None]:
# I would like to reshape
# from (1, 507, 85)
# to (1, 3, 13, 13, 85)

In [None]:
# Fresh random input with the anchor and spatial axes already merged
# (507 = 3 * 13 * 13).
x2 = torch.randn(1, 507, 85)

x2.shape

torch.Size([1, 507, 85])

In [None]:
# Parentheses on the left-hand side split the 507-long axis back into
# (num_anchors, h, w); the sizes of the new axes are passed as keyword arguments.
y3 = einops.rearrange(
    x2,
    "b (num_anchors h w) p -> b num_anchors h w p",
    num_anchors=3, h=13, w=13)

y3.shape

torch.Size([1, 3, 13, 13, 85])

In [None]:
# Same split, but w is left out: einops infers it from the remaining
# length, 507 / (3 * 13) = 13.
y3 = einops.rearrange(
    x2,
    "b (num_anchors h w) p -> b num_anchors h w p",
    num_anchors=3, h=13)

y3.shape

torch.Size([1, 3, 13, 13, 85])

In [None]:
# let's rewrite the code snippet I showed
# you earlier with einops, to round out
# this concept

In [None]:
# Raw prediction map: 3 anchors * 85 values stacked along the channel axis.
pred = torch.randn(1, 3 * 85, 13, 13)

# Step 1: move channels last and make the memory layout contiguous so that
# .view() is allowed.
pred_channels_last = pred.permute(0, 2, 3, 1).contiguous()
# Step 2: collapse everything except the trailing 85 values into one axis.
pred_reshaped = pred_channels_last.view(1, -1, 85)

pred.shape, pred_reshaped.shape

(torch.Size([1, 255, 13, 13]), torch.Size([1, 507, 85]))

In [None]:
# einops rewrite of the permute/contiguous/view snippet above.
# The merged output axis is ordered (h w num_anchors): after permute(0, 2, 3, 1)
# the channel axis is innermost, so view(1, -1, 85) enumerates rows by h, then
# w, then anchor. Writing "(num_anchors h w)" instead would give the same shape
# but a different row order.
ein_pred = einops.rearrange(pred,
                            "b (num_anchors p) h w -> b (h w num_anchors) p",
                            num_anchors=3,
                            h=13,
                            w=13)

# Matching shapes are not enough -- confirm the values line up with the
# PyTorch version element for element.
assert torch.equal(ein_pred, pred_reshaped)

ein_pred.shape

torch.Size([1, 507, 85])

In [None]:
# Next up: the awkward-looking reshape.

# Grid width and height.
h, w = 13, 13

# 0, 1, ..., w - 1 as float32.
t = torch.arange(w, dtype=torch.float32)

t.shape

torch.Size([13])

In [None]:
# Plain PyTorch: surround the 13 values with singleton axes; -1 lets reshape
# infer that axis length.
ugly_c_x = t.reshape(1, 1, -1, 1)

ugly_c_x.shape

torch.Size([1, 1, 13, 1])

In [None]:
# einops version: a literal 1 in the output pattern inserts a new axis of length 1.
nice_c_x = einops.rearrange(t, "w -> 1 1 w 1")

nice_c_x.shape

torch.Size([1, 1, 13, 1])

In [None]:
# Check that both approaches produce the same tensor.
torch.allclose(ugly_c_x, nice_c_x)

True

In [None]:
# Starting shape for the squeeze examples that follow.
nice_c_x.shape

torch.Size([1, 1, 13, 1])

In [None]:
# squeeze() without a dim removes every axis of length 1.
torch.squeeze(nice_c_x).shape

torch.Size([13])

In [None]:
# With an explicit dim, only that axis (here axis 0) is removed.
torch.squeeze(nice_c_x, 0).shape

torch.Size([1, 13, 1])

In [None]:
# With einops the pattern states exactly which length-1 axes survive:
# one 1 is kept, the other two are dropped.
einops.rearrange(nice_c_x, "1 1 w 1 -> 1 w").shape

torch.Size([1, 13])

In [None]:
# The Rearrange PyTorch layer
from einops.layers.torch import Rearrange

In [None]:
class ANeuralNetwork(nn.Module):
  """1x1 convolution head whose output channels are unpacked per anchor.

  The convolution emits ``num_anchors_per_cell * (4 + 1 + num_classes)``
  channels; an einops ``Rearrange`` layer then splits that channel axis into
  an anchor axis and a trailing per-anchor prediction axis.

  Args:
    in_channels: number of channels of the incoming feature map.
    num_anchors_per_cell: number of anchors predicted at each spatial location.
    num_classes: number of classes; every anchor gets ``4 + 1 + num_classes``
      values (presumably 4 box coordinates + 1 objectness score -- confirm).
  """

  def __init__(self,
               in_channels: int,
               num_anchors_per_cell: int,
               num_classes: int):
    super().__init__()

    values_per_anchor = 4 + 1 + num_classes

    # Pointwise convolution: changes only the channel count, not h or w.
    self.conv = nn.Conv2d(
        in_channels,
        num_anchors_per_cell * values_per_anchor,
        kernel_size=1,
        stride=1,
    )

    # Split the channel axis into (anchor, value) and move the values last.
    self.rearrange = Rearrange(
        "b (num_anchors_per_cell p) h w -> b num_anchors_per_cell h w p",
        num_anchors_per_cell=num_anchors_per_cell,
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Map ``(b, in_channels, h, w)`` to ``(b, num_anchors_per_cell, h, w, p)``."""
    return self.rearrange(self.conv(x))


In [None]:
net = ANeuralNetwork(in_channels=512, num_anchors_per_cell=3, num_classes=80)

input_x = torch.randn(size=(1, 512, 13, 13))

output = net(input_x)

output.shape

torch.Size([1, 3, 13, 13, 85])