In [57]:
# Install dependencies.
# Every command below is commented out; uncomment the ones you need when
# setting up a fresh environment. `torchviz` and `torchsummary` are imported
# by the next cell.
# NOTE(review): inside a notebook, `%pip install pkg==<version>` targets the
# kernel's own environment and pins the version — consider switching from `!pip`.
#!apt update && apt install python-pydot python-pydot-ng graphviz -y
#!pip install graphviz
#!pip install torchviz
#!pip install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [2]:
import torch
from torch import nn 
import torchvision.models

from torchviz import make_dot
from torchsummary import summary

In [3]:
from Models.Perceiver_archs.perceiver_module import (
    PerceiverEncoder, PerceiverDecoder, PerceiverIO
)
from Models.Perceiver_archs.task_spec_adapter import (
    Image_FourierEnc, ClassificationOutputAdapter
)

# Model dimensions, each named exactly once. The encoder and the decoder both
# take the latent channel width D; sharing one constant keeps them in sync.
IMAGE_SHAPE = (224, 224, 3)  # M = 224 * 224
NUM_FREQUENCY_BANDS = 64
NUM_CLASSES = 1000  # E
NUM_OUTPUT_QUERY_CHANNELS = 1024  # F
NUM_LATENTS = 512  # N
NUM_LATENT_CHANNELS = 1024  # D

# Fourier-encodes pixel positions and flattens along spatial dimensions
input_adapter = Image_FourierEnc(
    image_shape=IMAGE_SHAPE,
    num_frequency_bands=NUM_FREQUENCY_BANDS,
)
# Projects generic Perceiver decoder output to specified number of classes
output_adapter = ClassificationOutputAdapter(
    num_classes=NUM_CLASSES,
    num_output_query_channels=NUM_OUTPUT_QUERY_CHANNELS,
)

# Generic Perceiver encoder
encoder = PerceiverEncoder(
    input_adapter=input_adapter,
    num_latents=NUM_LATENTS,
    num_latent_channels=NUM_LATENT_CHANNELS,
    # C: the cross-attention query/key width is taken from the input adapter
    num_cross_attention_qk_channels=input_adapter.num_input_channels,
    num_cross_attention_heads=1,
    num_self_attention_heads=2,
    num_self_attention_layers_per_block=2,
    num_self_attention_blocks=3,
    dropout=0.0,
)
# Generic Perceiver decoder (same latent width D as the encoder)
decoder = PerceiverDecoder(
    output_adapter=output_adapter,
    num_latent_channels=NUM_LATENT_CHANNELS,
    num_cross_attention_heads=1,
    dropout=0.0,
)
# Perceiver IO image classifier
model = PerceiverIO(encoder, decoder)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [4]:
img_shape = (224, 224, 3)

# Architecture summary.
# Use the GPU when one is present and fall back to the CPU otherwise, so the
# cell also runs on CPU-only machines. summary() builds its dummy input on the
# device named here, so the model must live on that same device; it accepts the
# strings "cuda" or "cpu".
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(summary_device)
summary(model, img_shape, device=summary_device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
  Image_FourierEnc-1           [-1, 50176, 261]               0
         LayerNorm-2            [-1, 512, 1024]           2,048
         LayerNorm-3           [-1, 50176, 261]             522
            Linear-4             [-1, 512, 261]         267,525
            Linear-5           [-1, 50176, 261]          68,382
            Linear-6           [-1, 50176, 261]          68,382
           Dropout-7           [-1, 512, 50176]               0
            Linear-8            [-1, 512, 1024]         268,288
MultiHeadAttention-9            [-1, 512, 1024]               0
   CrossAttention-10            [-1, 512, 1024]               0
          Dropout-11            [-1, 512, 1024]               0
         Residual-12            [-1, 512, 1024]               0
        LayerNorm-13            [-1, 512, 1024]           2,048
           Linear-14            [-1, 51

In [5]:
## Visualization
# Do not forget to move the model to the CPU: the input tensor below is
# created on the default (CPU) device.
model = model.to("cpu")

# Build the pipeline: one forward pass on a random input.
x = torch.randn(1, 224, 224, 3)
# Two-step alternative, kept for reference:
#   enc_emb = encoder(x)
#   out = decoder(enc_emb)
out = model(x)

# Draw the autograd graph of `out`, labelling nodes with the parameter names,
# then write it to disk as a PNG.
named_params = dict(model.named_parameters())
graph = make_dot(out, params=named_params, show_attrs=False, show_saved=False)
graph.render("./Models/example", format='png')

'Models/example.png'