#### Simple Vision Transformer

In [4]:
import os
import sys

import torch

# Add the directory two levels up (expected to contain the `peekvit` package) to the import path.
sys.path.append(os.path.join(os.path.abspath(''), '../../'))

from peekvit.models.vit import VisionTransformer

# Fix the RNG so the random weight initialization and the random inputs used below are reproducible.
SEED = 0
torch.manual_seed(SEED)

# A 4-layer ViT whose layer sizes match ViT-B/16 (768 hidden, 12 heads, 3072 MLP, 16x16 patches).
vit = VisionTransformer(
    image_size=224,
    patch_size=16,
    num_classes=1000,
    hidden_dim=768,
    num_layers=4,
    num_class_tokens=1,
    num_heads=12,
    mlp_dim=3072,
    dropout=0.1,
)

vit


VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): ViTEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): Sequential(
      (0): ViTBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (self_attention): SelfAttention(
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
        )
        (dropout): Dropout(p=0.1, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (1): ViTBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (self_attention): SelfAttention(
          (self_attention): MultiheadAttention(
            (out_proj): NonDynami

In [9]:
# One random RGB image at the model's 224x224 input resolution.
images = torch.randn(1, 3, 224, 224)

# Inference only: disable dropout and skip building the autograd graph.
vit.eval()
with torch.no_grad():
    logits = vit(images)

# One row of class scores per image.
assert logits.shape == (1, 1000), logits.shape
logits.shape

torch.Size([1, 1000])


#### Pretrained ViT from torch or timm weights 

Make sure the dimensions of the transformer match those of the weights you download.

In [14]:
from torchvision.models.vision_transformer import ViT_B_16_Weights

# Where to take pretrained weights from: None (random init), 'torch' or 'timm'.
# The architecture below is ViT-B/16, matching torchvision's ViT_B_16 weights, see
# https://github.com/pytorch/vision/blob/806dba678d5b01f6e8a46f7c48fdf8c09369a267/torchvision/models/vision_transformer.py#L351
# For timm any pretrained weights are supported, but their dimensions must still match:
# deit_base_patch16_224 matches ViT-B/16, whereas e.g. deit_small_patch16_224 would need
# hidden_dim=384, num_heads=6, mlp_dim=1536.
PRETRAINED_SOURCE = None

if PRETRAINED_SOURCE is None:
    pretrained_kwargs = {}
elif PRETRAINED_SOURCE == 'torch':
    pretrained_kwargs = {'torch_pretrained_weights': str(ViT_B_16_Weights['IMAGENET1K_V1'])}
elif PRETRAINED_SOURCE == 'timm':
    pretrained_kwargs = {'timm_pretrained_weights': ['facebookresearch/deit:main', 'deit_base_patch16_224']}
else:
    raise ValueError(f"PRETRAINED_SOURCE must be None, 'torch' or 'timm', got {PRETRAINED_SOURCE!r}")

vit = VisionTransformer(
    image_size=224,
    patch_size=16,
    num_classes=1000,
    hidden_dim=768,
    num_layers=12,
    num_class_tokens=1,
    num_heads=12,
    mlp_dim=3072,
    dropout=0.1,
    **pretrained_kwargs,
)



You can also provide local paths to torch or timm pretrained weights.

#### Load state from checkpoint 

Checkpoints are saved automatically during training; you can load them like so:

In [6]:
from peekvit.utils.utils import load_state

# Checkpoints are written per epoch under <run_dir>/checkpoints/epoch_XXX.pth.
checkpoint_path = '../runs/2024-01-24-10-30-25/checkpoints/epoch_000.pth'

# With model=None, load_state builds the model from the configuration stored in the checkpoint
# (it reports "Creating model based on saved state").
model, optimizer, epoch, model_args, noise_args = load_state(
    path=checkpoint_path,
    model=None,
)

Creating model based on saved state
VisionTransformer
{'_target_': 'peekvit.models.vit.VisionTransformer', 'image_size': 160, 'patch_size': 8, 'num_classes': 10, 'hidden_dim': 256, 'mlp_dim': 768, 'num_layers': 4, 'num_heads': 4}


In [7]:
# Constructor arguments of the checkpointed model, as returned by load_state in the previous cell.
model_args

{'image_size': 160,
 'patch_size': 8,
 'num_classes': 10,
 'hidden_dim': 256,
 'mlp_dim': 768,
 'num_layers': 4,
 'num_heads': 4}

#### Load pretrained ViT parameters into another ViT

You can also initialize a ViT with a different architecture from pretrained weights, as long as the dimensions match. If the model has a different head dimension or additional parameters, those will be randomly initialized.

In [10]:
from peekvit.models.rankvit import RankVisionTransformer

# The architecture must match the one saved in the checkpoint; any parameters the
# checkpoint does not provide (e.g. RankViT-specific ones) stay randomly initialized.
rankvit = RankVisionTransformer(
    image_size=160,
    patch_size=8,
    num_classes=10,
    hidden_dim=256,
    mlp_dim=768,
    num_layers=4,
    num_heads=4,
)

# load_state returns (model, optimizer, epoch, model_args, noise_args); unpack by position
# so that model_args is the saved configuration and not the optimizer.
rankvit, _optimizer, _epoch, model_args, *_ = load_state(
    path='../runs/2024-01-24-10-30-25/checkpoints/epoch_100.pth',
    model=rankvit,
)
rankvit

RankVisionTransformer(
  (conv_proj): Conv2d(3, 256, kernel_size=(8, 8), stride=(8, 8))
  (encoder): ViTEncoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (0): RankViTBlock(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (self_attention): SelfAttention(
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=256, out_features=768, bias=True)
          (fc2): Linear(in_features=768, out_features=256, bias=True)
        )
      )
      (1): RankViTBlock(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (self_attention): SelfAttention(
          (self_attention): MultiheadAttention(
            (out_proj): Non