#### Simple Vision Transformer

In [1]:
import os, sys, torch
sys.path.append(os.path.join(os.path.abspath(''), '../../'))

from AdaViT.models.vit import VisionTransformer

vit = VisionTransformer(
    image_size=224,
    patch_size=16,
    num_classes=1000,
    hidden_dim=768,
    num_layers=4,
    num_class_tokens=1,
    num_heads=12,
    mlp_dim=3072,
    dropout=0.1)

print(vit) 


VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): ViTEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): Sequential(
      (0): ViTBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (self_attention): SelfAttention(
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
        )
        (dropout): Dropout(p=0.1, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (1): ViTBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (self_attention): SelfAttention(
          (self_attention): MultiheadAttention(
            (out_proj): NonDynami

In [2]:
!pip install plotly



In [3]:
from AdaViT.utils.visualize import plot_distribution
import matplotlib.pyplot as plt

target_dist = torch.distributions.StudentT(loc=7, scale=0.25, df=5)
plot_distribution(target_dist)

AttributeError: 'StudentT' object has no attribute 'prob'

In [2]:
x = torch.randn(1, 3, 224, 224)
logits = vit(x)
print(logits.shape)

torch.Size([1, 1000])


#### Pretrained ViT from torch or timm weights 

Make sure the dimensions of the transformer match with the weights you download

In [3]:
from torchvision.models.vision_transformer import ViT_B_16_Weights

# see https://github.com/pytorch/vision/blob/806dba678d5b01f6e8a46f7c48fdf8c09369a267/torchvision/models/vision_transformer.py#L351
# for timm any pretrained weights are supported

vit = VisionTransformer(
    image_size=224,
    patch_size=16,
    num_classes=10,
    hidden_dim=192,
    num_layers=12,
    num_class_tokens=1,
    num_heads=3,
    mlp_dim=768,
    dropout=0.1,
    #torch_pretrained_weights=str(ViT_B_16_Weights['IMAGENET1K_V1']),
    timm_pretrained_weights=['facebookresearch/deit:main', 'deit_small_patch16_224']
    )

Loading timm pretrained weights:  ['facebookresearch/deit:main', 'deit_small_patch16_224']
Downloading timm pretrained weights:  ['facebookresearch/deit:main', 'deit_small_patch16_224']


Using cache found in /home/studio-lab-user/.cache/torch/hub/facebookresearch_deit_main
  def deit_tiny_patch16_224(pretrained=False, **kwargs):
  def deit_small_patch16_224(pretrained=False, **kwargs):
  def deit_base_patch16_224(pretrained=False, **kwargs):
  def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
  def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
  def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
  def deit_base_patch16_384(pretrained=False, **kwargs):
  def deit_base_distilled_patch16_384(pretrained=False, **kwargs):


Loading weights for a different number of classes. Replacing head with random weights. You should fine-tune the model.


RuntimeError: Error(s) in loading state_dict for VisionTransformer:
	size mismatch for class_tokens: copying a param with shape torch.Size([1, 1, 384]) from checkpoint, the shape in current model is torch.Size([1, 1, 192]).
	size mismatch for conv_proj.weight: copying a param with shape torch.Size([384, 3, 16, 16]) from checkpoint, the shape in current model is torch.Size([192, 3, 16, 16]).
	size mismatch for conv_proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.pos_embedding: copying a param with shape torch.Size([1, 197, 384]) from checkpoint, the shape in current model is torch.Size([1, 197, 192]).
	size mismatch for encoder.layers.0.ln_1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.0.ln_1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.0.self_attention.self_attention.in_proj_weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([576, 192]).
	size mismatch for encoder.layers.0.self_attention.self_attention.in_proj_bias: copying a param with shape torch.Size([1152]) from checkpoint, the shape in current model is torch.Size([576]).
	size mismatch for encoder.layers.0.self_attention.self_attention.out_proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([192, 192]).
	size mismatch for encoder.layers.0.self_attention.self_attention.out_proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.0.ln_2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.0.ln_2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.0.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 192]).
	size mismatch for encoder.layers.0.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.layers.0.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([192, 768]).
	size mismatch for encoder.layers.0.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.1.ln_1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.1.ln_1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.1.self_attention.self_attention.in_proj_weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([576, 192]).
	size mismatch for encoder.layers.1.self_attention.self_attention.in_proj_bias: copying a param with shape torch.Size([1152]) from checkpoint, the shape in current model is torch.Size([576]).
	size mismatch for encoder.layers.1.self_attention.self_attention.out_proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([192, 192]).
	size mismatch for encoder.layers.1.self_attention.self_attention.out_proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.1.ln_2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.1.ln_2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.1.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 192]).
	size mismatch for encoder.layers.1.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.layers.1.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([192, 768]).
	size mismatch for encoder.layers.1.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.2.ln_1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.2.ln_1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.2.self_attention.self_attention.in_proj_weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([576, 192]).
	size mismatch for encoder.layers.2.self_attention.self_attention.in_proj_bias: copying a param with shape torch.Size([1152]) from checkpoint, the shape in current model is torch.Size([576]).
	size mismatch for encoder.layers.2.self_attention.self_attention.out_proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([192, 192]).
	size mismatch for encoder.layers.2.self_attention.self_attention.out_proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.2.ln_2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.2.ln_2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.2.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 192]).
	size mismatch for encoder.layers.2.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.layers.2.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([192, 768]).
	size mismatch for encoder.layers.2.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.3.ln_1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.3.ln_1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.3.self_attention.self_attention.in_proj_weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([576, 192]).
	size mismatch for encoder.layers.3.self_attention.self_attention.in_proj_bias: copying a param with shape torch.Size([1152]) from checkpoint, the shape in current model is torch.Size([576]).
	size mismatch for encoder.layers.3.self_attention.self_attention.out_proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([192, 192]).
	size mismatch for encoder.layers.3.self_attention.self_attention.out_proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.3.ln_2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.3.ln_2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.3.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 192]).
	size mismatch for encoder.layers.3.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.layers.3.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([192, 768]).
	size mismatch for encoder.layers.3.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.4.ln_1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.4.ln_1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.4.self_attention.self_attention.in_proj_weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([576, 192]).
	size mismatch for encoder.layers.4.self_attention.self_attention.in_proj_bias: copying a param with shape torch.Size([1152]) from checkpoint, the shape in current model is torch.Size([576]).
	size mismatch for encoder.layers.4.self_attention.self_attention.out_proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([192, 192]).
	size mismatch for encoder.layers.4.self_attention.self_attention.out_proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.4.ln_2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.4.ln_2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.4.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 192]).
	size mismatch for encoder.layers.4.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.layers.4.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([192, 768]).
	size mismatch for encoder.layers.4.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.5.ln_1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.5.ln_1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.5.self_attention.self_attention.in_proj_weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([576, 192]).
	size mismatch for encoder.layers.5.self_attention.self_attention.in_proj_bias: copying a param with shape torch.Size([1152]) from checkpoint, the shape in current model is torch.Size([576]).
	size mismatch for encoder.layers.5.self_attention.self_attention.out_proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([192, 192]).
	size mismatch for encoder.layers.5.self_attention.self_attention.out_proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.5.ln_2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.5.ln_2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.5.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 192]).
	size mismatch for encoder.layers.5.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.layers.5.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([192, 768]).
	size mismatch for encoder.layers.5.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.6.ln_1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.6.ln_1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.6.self_attention.self_attention.in_proj_weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([576, 192]).
	size mismatch for encoder.layers.6.self_attention.self_attention.in_proj_bias: copying a param with shape torch.Size([1152]) from checkpoint, the shape in current model is torch.Size([576]).
	size mismatch for encoder.layers.6.self_attention.self_attention.out_proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([192, 192]).
	size mismatch for encoder.layers.6.self_attention.self_attention.out_proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.6.ln_2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.6.ln_2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.6.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 192]).
	size mismatch for encoder.layers.6.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.layers.6.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([192, 768]).
	size mismatch for encoder.layers.6.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.7.ln_1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.7.ln_1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.7.self_attention.self_attention.in_proj_weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([576, 192]).
	size mismatch for encoder.layers.7.self_attention.self_attention.in_proj_bias: copying a param with shape torch.Size([1152]) from checkpoint, the shape in current model is torch.Size([576]).
	size mismatch for encoder.layers.7.self_attention.self_attention.out_proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([192, 192]).
	size mismatch for encoder.layers.7.self_attention.self_attention.out_proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.7.ln_2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.7.ln_2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.7.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 192]).
	size mismatch for encoder.layers.7.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.layers.7.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([192, 768]).
	size mismatch for encoder.layers.7.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.8.ln_1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.8.ln_1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.8.self_attention.self_attention.in_proj_weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([576, 192]).
	size mismatch for encoder.layers.8.self_attention.self_attention.in_proj_bias: copying a param with shape torch.Size([1152]) from checkpoint, the shape in current model is torch.Size([576]).
	size mismatch for encoder.layers.8.self_attention.self_attention.out_proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([192, 192]).
	size mismatch for encoder.layers.8.self_attention.self_attention.out_proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.8.ln_2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.8.ln_2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.8.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 192]).
	size mismatch for encoder.layers.8.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.layers.8.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([192, 768]).
	size mismatch for encoder.layers.8.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.9.ln_1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.9.ln_1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.9.self_attention.self_attention.in_proj_weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([576, 192]).
	size mismatch for encoder.layers.9.self_attention.self_attention.in_proj_bias: copying a param with shape torch.Size([1152]) from checkpoint, the shape in current model is torch.Size([576]).
	size mismatch for encoder.layers.9.self_attention.self_attention.out_proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([192, 192]).
	size mismatch for encoder.layers.9.self_attention.self_attention.out_proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.9.ln_2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.9.ln_2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.9.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 192]).
	size mismatch for encoder.layers.9.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.layers.9.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([192, 768]).
	size mismatch for encoder.layers.9.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.10.ln_1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.10.ln_1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.10.self_attention.self_attention.in_proj_weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([576, 192]).
	size mismatch for encoder.layers.10.self_attention.self_attention.in_proj_bias: copying a param with shape torch.Size([1152]) from checkpoint, the shape in current model is torch.Size([576]).
	size mismatch for encoder.layers.10.self_attention.self_attention.out_proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([192, 192]).
	size mismatch for encoder.layers.10.self_attention.self_attention.out_proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.10.ln_2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.10.ln_2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.10.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 192]).
	size mismatch for encoder.layers.10.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.layers.10.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([192, 768]).
	size mismatch for encoder.layers.10.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.11.ln_1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.11.ln_1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.11.self_attention.self_attention.in_proj_weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([576, 192]).
	size mismatch for encoder.layers.11.self_attention.self_attention.in_proj_bias: copying a param with shape torch.Size([1152]) from checkpoint, the shape in current model is torch.Size([576]).
	size mismatch for encoder.layers.11.self_attention.self_attention.out_proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([192, 192]).
	size mismatch for encoder.layers.11.self_attention.self_attention.out_proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.11.ln_2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.11.ln_2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.layers.11.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 192]).
	size mismatch for encoder.layers.11.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.layers.11.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([192, 768]).
	size mismatch for encoder.layers.11.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.ln.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for encoder.ln.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for head.weight: copying a param with shape torch.Size([10, 384]) from checkpoint, the shape in current model is torch.Size([10, 192]).

You can also provide local paths to torch or timm pretrained weights.

#### Load state from checkpoint 

During training checkpoints are saved automatically, you can load them like so:

In [8]:
from AdaViT.utils.utils import load_state

#model, optimizer, epoch, model_args, noise_args = load_state(path='../runs/2024-01-24-10-30-25/checkpoints/epoch_000.pth', model=None)

In [10]:
#model_args

#### Load pretrained Vit params into other Vit 

You can also initialize a vit with a different architecture from pre-trained weights, as long as dimensions match. If the model has different head dim or additional paramters, they will be randomly initalized.

In [7]:
from AdaViT.models.rankvit import RankVisionTransformer

rankvit = RankVisionTransformer(
    image_size = 160, 
    patch_size  = 8, 
    num_classes  = 10, 
    hidden_dim = 256, 
    mlp_dim = 768, 
    num_layers = 4, 
    num_heads = 4
    )

rankvit, model_args, *_ = load_state(path='../runs/2024-01-24-10-30-25/checkpoints/epoch_100.pth', model=rankvit)
rankvit

FileNotFoundError: [Errno 2] No such file or directory: '../runs/2024-01-24-10-30-25/checkpoints/epoch_100.pth'