In [1]:
# tutorial found at https://huggingface.co/docs/transformers/en/model_doc/vit

In [1]:
import torch
from transformers import pipeline
from transformers import AutoImageProcessor, ViTConfig, ViTModel, ViTForImageClassification
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
# Choose the accelerator at run time instead of hard-coding device index 0,
# which only works when a GPU-like backend is present.
if torch.cuda.is_available():
    _device = "cuda:0"
elif torch.backends.mps.is_available():
    _device = "mps:0"
else:
    _device = "cpu"

# Half precision is kept on accelerators; the CPU fallback uses float32.
# NOTE(review): fp16 kernels are not guaranteed for every CPU op - confirm if
# float16 on CPU is ever wanted here.
_dtype = torch.float16 if _device != "cpu" else torch.float32

# NOTE: this rebinds the imported `pipeline` factory name to the resulting
# image-classification pipeline object; re-run the import line before
# building another pipeline.
pipeline = pipeline(
    task="image-classification",
    model="google/vit-base-patch16-224",
    torch_dtype=_dtype,
    device=_device,
)

Device set to use mps:0


In [2]:
# Build a ViT with random weights from a default ViTConfig (a
# vit-base-patch16-224 style configuration), then read the configuration back
# from the instantiated model.
model = ViTModel(ViTConfig())
configuration = model.config

In [3]:

# Load the demo dataset and take the first image of the "test" split.
# Indexing the row first and the column second decodes only that single image
# instead of materialising the whole "image" column.
# NOTE(review): trust_remote_code=True allows code shipped with the dataset
# repository to run locally - confirm it is still required.
dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
image = dataset["test"][0]["image"]

# Pre-processing that matches the checkpoint. Per the original note it resizes
# and crops the image around its centre - NOTE(review): verify against the
# processor's configuration.
# FIXME custom the parameters of cropping and of normalization
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

# Inference only: no gradients are needed.
inputs = image_processor(image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# model predicts one of the 1000 ImageNet classes
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])

Egyptian cat


In [4]:
# Print the module tree of the classifier; the submodule names shown here are
# the attribute paths used further down (e.g. model.vit.encoder.layer[i]).
print(model)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [5]:
# Geometry of the hooked activations: a 224x224 image cut into 16x16 patches
# gives 14 * 14 = 196 patch tokens, plus one classification token, each of
# width 768, across 12 encoder blocks.
num_patches = (224 // 16) ** 2
layer_dim = 768
out_dim = (num_patches + 1) * layer_dim
encoder_blocks = 12

# Fixed seed so the sampled coordinates are identical after a kernel restart.
SEED = 0
_rng = np.random.default_rng(SEED)

# One independent subsample per encoder block: 1% of the positions of the
# flattened (tokens * hidden_dim) activation vector, drawn without replacement.
rand_idx = [
    _rng.choice(out_dim, size=out_dim // 100, replace=False)
    for _ in range(encoder_blocks)
]

In [6]:
def wrapper_hook(layer, rand_idx, store=None):
    """Build a forward hook that records a subsample of a module's output.

    Args:
        layer: Key under which the recorded activations are appended.
        rand_idx: Integer indices into the flattened output to keep.
        store: Optional mapping ``layer -> list`` that receives the records.
            When omitted, the notebook-level ``feats`` dict is looked up each
            time the hook fires, so ``feats`` may be (re)created after the
            hook has been registered.

    Returns:
        A function with the ``(module, input, output)`` forward-hook signature.
        It returns ``None``, so the module's output is left unchanged.

    Note:
        The output is flattened over all of its dimensions, the batch
        dimension included, so ``rand_idx`` addresses that single flat vector.
    """

    def hook_func(module, input, output):
        target = feats if store is None else store
        # reshape(-1) turns the detached output into one flat vector.
        flat = output.detach().reshape(-1)
        # Subsample first, then cast: only the kept values are converted to
        # float16, which gives the same result as casting the whole vector.
        sampled = flat[rand_idx].half()
        # Create the per-layer list on demand instead of raising KeyError.
        target.setdefault(layer, []).append(sampled)

    return hook_func
    

In [7]:
# Inspect the `output` submodule of the encoder block at index 1, a candidate
# attachment point for forward hooks.
# FIXME add the hooks appropriately
print(model.vit.encoder.layer[1].output) # alternatives to inspect: .encoder or .classifier

ViTOutput(
  (dense): Linear(in_features=3072, out_features=768, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)


In [8]:
# Print the full structure of the first encoder block to see every submodule
# it contains.
print(model.vit.encoder.layer[0])

ViTLayer(
  (attention): ViTAttention(
    (attention): ViTSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
    )
    (output): ViTSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
  )
  (intermediate): ViTIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): ViTOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)


In [9]:
# Here we hook the output of the 12 transformer blocks.
# They all end with an MLP 768 > 3072 > 768 which, importantly, processes each
# patch independently and identically, so the output has shape
# (batch_size, num_tokens, hidden_dim).
# Leaving aside batch_size, the tokens (embeddings of the patches) are on the
# rows: (224^2 / 16^2) + 1 = 196 + 1, where the +1 is the classification token,
# a summary of the image.

# Detach hooks left over from a previous run of this cell. Without this,
# re-running the cell would discard the old handles while their hooks stay
# attached, so every forward pass would record each activation several times.
for _stale_handle in globals().get("hook_handle", []):
    _stale_handle.remove()

hook_handle = []
for block_idx in range(encoder_blocks):
    hook_handle.append(
        model.vit.encoder.layer[block_idx].output.register_forward_hook(
            wrapper_hook(block_idx, rand_idx[block_idx])
        )
    )


In [10]:
# Uncomment to detach all registered forward hooks from the model:
# for h in hook_handle:
#     h.remove()

In [11]:
# Start from one empty bucket per encoder block; the registered forward hooks
# append to these buckets during the forward pass below.
feats = {block_id: list() for block_id in range(encoder_blocks)}

# Gradient tracking is switched off for this forward pass.
with torch.no_grad():
    logits = model(**inputs).logits

# The arg-max over the last axis is the index of one of the 1000 ImageNet classes.
predicted_label = logits.argmax(dim=-1).item()
print(model.config.id2label[predicted_label])

Egyptian cat


In [12]:
# Size of the first record captured for the encoder block at index 2.
feats[2][0].size()

torch.Size([1512])

In [26]:
# Show the sampled indices used for the encoder block at index 2.
rand_idx[2]

array([  2914,   6080, 119734, ...,  45219,  27286, 144591])