In [1]:
import sys; sys.path.append('..')
import torchinfo
import torch

## Language Models

In [2]:
from language_models import DualAttnTransformerLM

# Instantiate a Dual Attention Transformer language model, then print a
# per-layer parameter summary. Every argument is passed by keyword.
dat_lm = DualAttnTransformerLM(
    # --- vocabulary and context window ---
    vocab_size=32000,        # vocabulary size
    max_block_size=1024,     # max context length
    pos_enc_type='RoPE',     # type of positional encoding to use

    # --- model size ---
    d_model=512,             # model dimension
    n_layers=6,              # number of layers
    dff=2048,                # feedforward intermediate dimension

    # --- attention heads ---
    n_heads_sa=4,            # number of self-attention heads
    n_heads_ra=4,            # number of relational attention heads

    # --- block configuration ---
    activation='swiglu',     # activation function of feedforward block
    norm_first=True,         # whether to use pre-norm or post-norm
    dropout_rate=0.1,        # dropout rate

    # --- symbol assignment mechanism ---
    symbol_retrieval='symbolic_attention',  # type of symbol assignment mechanism
    symbol_retrieval_kwargs={               # kwargs for symbol assignment mechanism
        'd_model': 512,
        'n_heads': 4,
        'n_symbols': 512,
    },
)

torchinfo.summary(dat_lm)

Layer (type:depth-idx)                                  Param #
DualAttnTransformerLM                                   --
├─ModuleDict: 1-1                                       --
│    └─Embedding: 2-1                                   16,384,000
│    │    └─Linear: 3-1                                 16,416,000
│    └─Dropout: 2-2                                     --
│    └─SymbolicAttention: 2-3                           524,288
│    │    └─Linear: 3-2                                 262,656
│    └─ModuleList: 2-4                                  --
│    │    └─DualAttnEncoderBlock: 3-3                   4,595,200
│    │    └─DualAttnEncoderBlock: 3-4                   4,595,200
│    │    └─DualAttnEncoderBlock: 3-5                   4,595,200
│    │    └─DualAttnEncoderBlock: 3-6                   4,595,200
│    │    └─DualAttnEncoderBlock: 3-7                   4,595,200
│    │    └─DualAttnEncoderBlock: 3-8                   4,595,200
│    └─Linear: 2-5                        

In [3]:
# Draw one random sequence of 129 token ids from [0, 32_000); the upper bound
# matches the vocab_size the language model was constructed with.
idx = torch.randint(low=0, high=32_000, size=(1, 129))

# x drops the last token and y drops the first, so y is x shifted by one
# position and both have length 128.
x = idx[:, :-1]
y = idx[:, 1:]

# Calling the model with inputs and targets returns both the logits and a loss.
logits, loss = dat_lm(x, y)
logits.shape # shape: (1, 128, 32000)

torch.Size([1, 128, 32000])

## Vision Models

In [4]:
from vision_models import VisionDualAttnTransformer

img_shape = (3, 224, 224)  # presumably (channels, height, width) -- TODO confirm
patch_size = (16, 16)

# Count the whole patches that fit along each of the last two image axes,
# then multiply to get the total number of patches per image.
patch_rows = img_shape[1] // patch_size[0]
patch_cols = img_shape[2] // patch_size[1]
n_patches = patch_rows * patch_cols

# Instantiate a Dual Attention Transformer image classifier, then print a
# per-layer parameter summary. Every argument is passed by keyword.
dat_vision = VisionDualAttnTransformer(
    # --- input and output ---
    image_shape=img_shape,   # shape of input image
    patch_size=patch_size,   # size of patch
    num_classes=1000,        # number of classes
    pool='cls',              # type of pooling (class token)

    # --- model size ---
    d_model=512,             # model dimension
    n_layers=6,              # number of layers
    dff=2048,                # feedforward intermediate dimension

    # --- attention heads ---
    n_heads_sa=4,            # number of self-attention heads
    n_heads_ra=4,            # number of relational attention heads

    # --- block configuration ---
    activation='swiglu',     # activation function of feedforward block
    norm_first=True,         # whether to use pre-norm or post-norm
    dropout_rate=0.1,        # dropout rate

    # --- symbol assignment mechanism ---
    symbol_retrieval='position_relative',  # type of symbol assignment mechanism
    symbol_retrieval_kwargs={              # kwargs for symbol assignment mechanism
        'symbol_dim': 512,
        # NOTE(review): the +1 presumably makes room for the class token
        # implied by pool='cls' -- confirm against the model code.
        'max_rel_pos': n_patches + 1,
    },
    # NOTE(review): presumably forwarded to the relational attention
    # module -- confirm against the model code.
    ra_kwargs={
        'symmetric_rels': True,
        'use_relative_positional_symbols': True,
    },
)

torchinfo.summary(dat_vision)

Layer (type:depth-idx)                             Param #
VisionDualAttnTransformer                          101,376
├─PositionRelativeSymbolRetriever: 1-1             --
│    └─RelativePositionalEncoding: 2-1             202,240
├─Sequential: 1-2                                  --
│    └─Rearrange: 2-2                              --
│    └─LayerNorm: 2-3                              1,536
│    └─Linear: 2-4                                 393,728
│    └─LayerNorm: 2-5                              1,024
├─Dropout: 1-3                                     --
├─ModuleList: 1-4                                  --
│    └─DualAttnEncoderBlock: 2-6                   --
│    │    └─Dropout: 3-1                           --
│    │    └─LayerNorm: 3-2                         1,024
│    │    └─DualAttention: 3-3                     1,180,672
│    │    └─LayerNorm: 3-4                         1,024
│    │    └─FeedForwardBlock: 3-5                  3,150,336
│    └─DualAttnEncoderBlock: 2-7    

In [5]:
# Feed a single random tensor of shape (1, *img_shape) through the vision
# model and inspect the shape of the logits it returns.
img = torch.randn((1, *img_shape))
logits = dat_vision(img)
logits.shape # shape: (1, 1000)

torch.Size([1, 1000])