In [38]:
from transformers import ViTConfig, ViTModel

# Initializing a ViT vit-base-patch16-224 style configuration
configuration = ViTConfig(hidden_size = 384)
# Initializing a model (with random weights) from the vit-base-patch16-224 style configuration
hf_model = ViTModel(configuration)

In [45]:
128 ** 2

16384

In [43]:
print(hf_model)

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=384, out_features=384, bias=True)
            (key): Linear(in_features=384, out_features=384, bias=True)
            (value): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=384, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
      

In [40]:
from vision_transformer import vit_base, vit_small

In [41]:
model = vit_small()

In [42]:
print(model)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): At

In [None]:
print(hf_model)

In [None]:
import utils
from vision_transformer import DINOHead

embed_dim = 768

teacher = utils.MultiCropWrapper(
    hf_model,
    DINOHead(embed_dim, 65536, False),
)

In [None]:
print(teacher)

In [None]:
from transformers import AutoImageProcessor
import torch
from datasets import load_dataset

dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
inputs = image_processor(image, return_tensors="pt")

with torch.no_grad():
    outputs = hf_model(**inputs)

In [None]:
outputs.pooler_output.shape

In [None]:
with torch.no_grad():
    outputs2 = model(inputs['pixel_values'])

In [35]:
outputs['pooler_output'].shape

torch.Size([1, 768])

In [29]:
inputs['pixel_values'][0]

tensor([[[ 0.1137,  0.1686,  0.1843,  ..., -0.1922, -0.1843, -0.1843],
         [ 0.1373,  0.1686,  0.1843,  ..., -0.1922, -0.1922, -0.2078],
         [ 0.1137,  0.1529,  0.1608,  ..., -0.2314, -0.2235, -0.2157],
         ...,
         [ 0.8353,  0.7882,  0.7333,  ...,  0.7020,  0.6471,  0.6157],
         [ 0.8275,  0.7961,  0.7725,  ...,  0.5843,  0.4667,  0.3961],
         [ 0.8196,  0.7569,  0.7569,  ...,  0.0745, -0.0510, -0.1922]],

        [[-0.8039, -0.8118, -0.8118,  ..., -0.8902, -0.8902, -0.8980],
         [-0.7882, -0.7882, -0.7882,  ..., -0.8745, -0.8745, -0.8824],
         [-0.8118, -0.8039, -0.7882,  ..., -0.8902, -0.8902, -0.8902],
         ...,
         [-0.2706, -0.3176, -0.3647,  ..., -0.4275, -0.4588, -0.4824],
         [-0.2706, -0.2941, -0.3412,  ..., -0.4824, -0.5451, -0.5765],
         [-0.2784, -0.3412, -0.3490,  ..., -0.7333, -0.7804, -0.8353]],

        [[-0.5451, -0.4667, -0.4824,  ..., -0.7412, -0.6941, -0.7176],
         [-0.5529, -0.5137, -0.4902,  ..., -0