In [2]:
from einops import rearrange, reduce
import numpy as np
from utils import guess

# A fixed seed makes every run of the notebook see exactly the same input.
# The array is laid out as (b, c, h, w) = (10, 32, 100, 200) and holds
# float64 samples from the standard normal distribution.
x = np.random.RandomState(seed=42).normal(size=(10, 32, 100, 200))

In [3]:
# Framework used for the rest of the notebook;
# must be one of 'chainer', 'tensorflow' or 'pytorch'.
flavour = 'pytorch'

In [4]:
# Convert the numpy array `x` into a tensor of the selected framework, so the
# einops calls below run unchanged on any backend.
# NOTE: this cell rebinds `x` to a framework tensor; re-run the cell that
# creates the numpy array before re-running this one.
if flavour not in ('chainer', 'tensorflow', 'pytorch'):
    # An `assert` would be skipped under `python -O`; fail explicitly instead.
    raise ValueError('unknown flavour: {!r}'.format(flavour))
print('selected {} backend'.format(flavour))
if flavour == 'tensorflow':
    import tensorflow as tf
    # The tape is entered manually and never exited in this cell, so every op
    # run afterwards is recorded; persistent=True allows several gradient
    # queries. `+ 0` turns the Variable into a tensor computed on the tape.
    tape = tf.GradientTape(persistent=True)
    tape.__enter__()
    x = tf.Variable(x) + 0
elif flavour == 'pytorch':
    import torch
    x = torch.from_numpy(x)
    x.requires_grad = True
else:  # flavour == 'chainer', guaranteed by the check above
    import chainer
    x = chainer.Variable(x)

selected pytorch backend


In [5]:
# Framework-specific tensor type together with its shape.
(type(x), x.shape)

(torch.Tensor, torch.Size([10, 32, 100, 200]))

In [6]:
# Move axis `c` from position 1 to the end; the other axes keep their order.
y = rearrange(x, 'b c h w -> b h w c')
guess(y.shape)  # 10 100 200 32

In [7]:
# Gradients flow through einops operations like through native ones:
# max over (h, w) -> transpose -> sum everything to a scalar.
y0 = x
y1 = reduce(y0, 'b c h w -> b c', 'max')
y2 = rearrange(y1, 'b c -> c b')
y3 = reduce(y2, 'c b -> ', 'sum')

# Each of the b * c maxima sends a total gradient of 1 back to x,
# so the printed sum is b * c (10 * 32 = 320 here).
if flavour == 'tensorflow':
    # TensorFlow tensors have no `.backward()`; query the persistent tape
    # that was opened when `x` was converted.
    x_grad = tape.gradient(y3, x)
else:
    # PyTorch and Chainer accumulate into `x.grad` on every backward call;
    # clear it first so re-running this cell prints the same value.
    if flavour == 'pytorch':
        x.grad = None
    else:
        x.cleargrad()
    y3.backward()
    x_grad = x.grad
print(reduce(x_grad, 'b c h w -> ', 'sum'))

tensor(320., dtype=torch.float64)


In [10]:
from einops import asnumpy

# `asnumpy` turns the backend tensor `y3` into a plain numpy array.
y3_numpy = asnumpy(y3)

print(type(y3_numpy))

<class 'numpy.ndarray'>


In [13]:
# Reminder of the input layout used by the examples below: (b, c, h, w).
x.shape

torch.Size([10, 32, 100, 200])

Flattening

In [14]:
# Merge c, h and w into one axis (c outermost, w innermost); b is kept.
y = rearrange(x, 'b c h w -> b (c h w)')
guess(y.shape)  # 10 640000

Space-to-depth

In [15]:
# Split h and w into 2-wide blocks (h1, w1 are offsets inside a block) and
# fold those offsets into the channel axis together with c.
y = rearrange(x, 'b c (h h1) (w w1) -> b (h1 w1 c) h w', h1=2, w1=2)
guess(y.shape)

Depth-to-space

In [16]:
# Inverse of space-to-depth: unpack h1, w1 from the channel axis and
# interleave them back into the spatial axes.
y = rearrange(x, 'b (h1 w1 c) h w -> b c (h h1) (w w1)', h1=2, w1=2)
guess(y.shape)

Max-pooling

In [17]:
# Reduced axes don't need a name: the literal `2` in `(h 2)` / `(w 2)` is the
# pooling window, and taking the max over it gives 2x2 max-pooling.
y = reduce(x, 'b c (h 2) (w 2) -> b c h w', 'max')
guess(y.shape)

In [19]:
# Indexing the first axis drops the batch dimension: one example is (c, h, w).
x[0].shape

torch.Size([32, 100, 200])

In [22]:
# Models usually accept only batches, so to run a single image:
# take the first three channels of the first example as an (h, w, c) image,
image = rearrange(x[0, :3], 'c h w -> h w c')
# add a dummy batch axis of length 1 (written `()` in the pattern),
y = rearrange(image, 'h w c -> () c h w')
# pretend a classification network produced the output -- here we just
# flatten the non-batch axes -- and then drop the dummy batch axis again.
predictions = rearrange(
    rearrange(y, 'b c h w -> b (c h w)'),
    '() classes -> classes',
)