In [10]:
import os

from datasets import load_dataset
from einops import rearrange
from transformers import ViTImageProcessor

In [2]:
# Locations are overridable via environment variables so the notebook is not tied to one
# machine; the defaults are the original hard-coded paths.
CIFAR10_DIR = os.environ.get("CIFAR10_DIR", "/data/jc/datasets/cifar-10")
MODEL_DIR = os.environ.get(
    "VIT_CIFAR10_DIR", "../weights/vit-base-patch16-224-in21k-finetuned-cifar10/"
)

# streaming=True: iterate the test split lazily instead of materialising it up front.
val_dataset = load_dataset(CIFAR10_DIR, split="test", streaming=True)

processor = ViTImageProcessor.from_pretrained(MODEL_DIR)

def preprocess_function(item):
    """Turn a batch of raw CIFAR-10 rows into model inputs.

    Used with ``datasets.map(..., batched=True)``, so ``item`` is a batch: a mapping
    from column name to a list of values. Only the "img" and "label" columns are read.

    Returns the image processor's output (PyTorch tensors, since ``return_tensors="pt"``)
    with an extra "labels" entry carrying ``item["label"]`` through unchanged.
    """
    # The pretrained processor prepares the images for the model (the original note: resize
    # to the model's input size).
    encoded = processor(return_tensors="pt", images=item["img"])
    encoded.update(labels=item["label"])
    return encoded

# Preprocess in batches; the raw "img" column is dropped from the mapped output.
val_dataset = val_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=["img"],
)

In [3]:
from transformers import ViTForImageClassification

# Classification model loaded from the same local checkpoint directory as the image processor.
model = ViTForImageClassification.from_pretrained(
    "../weights/vit-base-patch16-224-in21k-finetuned-cifar10"
)

In [4]:
import torch

# Seed the RNG so the random probe tensor, and every check that consumes it, is reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Random stand-in for hidden states fed to the attention projections.
# Assumes layout (batch, tokens, hidden); the token count 16 is arbitrary.
# The last dim must equal the projection's input size (768).
inputs = torch.randn(1, 16, 768)

In [5]:
# Query/key projection modules of the first encoder layer's self-attention.
# NOTE(review): key_fn is not used by any of the cells visible here.
attention_block = model.vit.encoder.layer[0].attention.attention
query_fn, key_fn = attention_block.query, attention_block.key

In [6]:
# For an nn.Linear the weight is stored as (out_features, in_features).
query_fn.weight.size()

torch.Size([768, 768])

In [7]:
# Reference result: the projection module's own forward pass.
query = query_fn(inputs)
# Manual y = x W^T + b: contract the inputs' hidden dim d with the weight's input dim.
query_manually = torch.einsum("bnd,md->bnm", inputs, query_fn.weight).add(query_fn.bias)
torch.allclose(query, query_manually, atol=1e-6)

True

In [9]:
# Let the attention module split the projected queries into heads itself, so the manual
# per-head computation can be checked against the library's own head layout.
# NOTE(review): transpose_for_scores is an internal helper of the installed transformers
# version; confirm it still exists after upgrading transformers.
query_head = (
    model.vit.encoder.layer[0]
    .attention.attention
    .transpose_for_scores(query)
)
print(query_head.shape)

torch.Size([1, 12, 16, 64])


In [20]:
# Split the fused (h*o, d) query weight into per-head (h, o, d) blocks and compute every
# head's queries in one einsum. The head count comes from the model config, not a
# hard-coded 12, so the check stays valid for checkpoints with a different head count.
num_heads = model.config.num_attention_heads
query_manually_weight = rearrange(query_fn.weight, "(h o) d -> h o d", h=num_heads)
# The bias gets a singleton token axis so it broadcasts over n in the (b, h, n, o) result.
query_manually_bias = rearrange(query_fn.bias, "(h o) -> h 1 o", h=num_heads)
query_manually_head = (
    torch.einsum("b n d, h o d -> b h n o", inputs, query_manually_weight)
    + query_manually_bias
)

In [21]:
# The per-head einsum must reproduce the model's own head split of the full projection.
torch.allclose(input=query_head, other=query_manually_head, atol=1e-6)

True