In [4]:
import torch
from perceiver_pytorch import Perceiver

model = Perceiver(
    input_channels = 3,          # number of channels for each token of the input
    input_axis = 2,              # number of axis for input data (2 for images, 3 for video)
    num_freq_bands = 6,          # number of freq bands, with original value (2 * K + 1)
    max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
    depth = 6,                   # depth of net. The shape of the final attention mechanism will be:
                                 #   depth * (cross attention -> self_per_cross_attn * self attention)
    num_latents = 256,           # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,         # number of dimensions per cross attention head
    latent_dim_head = 64,        # number of dimensions per latent self attention head
    num_classes = 1000,          # output number of classes
    attn_dropout = 0.,
    ff_dropout = 0.,
    weight_tie_layers = False   # whether to weight tie layers (optional, as indicated in the diagram)
    #fourier_encode_data = True,  # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself
    #self_per_cross_attn = 2      # number of self attention blocks per cross attention
)

img = torch.randn(1, 224, 224, 3) # 1 imagenet image, pixelized

model(img)


tensor([[-2.8210e-01,  2.9318e-01,  7.0946e-01, -8.8673e-01,  4.2799e-01,
         -2.7094e-01,  1.3906e-01,  8.2312e-01, -8.8990e-01, -4.1665e-01,
         -1.9299e-01,  5.5300e-01,  1.6493e-01,  4.4678e-01, -5.7664e-01,
          3.8312e-01,  5.8950e-01, -2.0961e-01, -1.6305e-02,  2.8482e-01,
         -3.0858e-01, -1.2850e-01, -2.1587e-02, -1.2554e-01, -6.3441e-01,
          7.3826e-01,  1.6357e-01,  2.1592e-02, -1.6491e-01, -5.9727e-02,
          3.1662e-01, -2.6719e-01,  4.6839e-01,  3.4103e-02,  3.5765e-01,
          1.3757e-01,  2.7838e-01, -8.2405e-02, -3.9065e-01,  9.8651e-01,
          8.5919e-01,  4.2573e-01, -2.3971e-01,  1.0989e+00, -2.8762e-01,
         -1.3976e-01, -7.9341e-01,  1.0641e+00,  6.3552e-01,  9.3253e-02,
         -2.5499e-01,  2.4313e-01, -6.2197e-01, -9.3680e-01, -8.6149e-01,
          5.3014e-01, -1.5029e+00,  5.0285e-01,  2.2240e-01,  5.3948e-01,
          9.6929e-01,  1.4928e+00, -1.2208e+00, -1.2596e-01, -7.8990e-01,
          9.0841e-01, -7.2075e-01,  3.

In [5]:
model(img).shape

torch.Size([1, 1000])