In [8]:
from utils.ggem import GGeM
from transformers import AutoModel, AutoImageProcessor
import torchvision.transforms
import PIL.Image

## The number of groups G is a pre-defined hyper-parameter, where we suggest using the number of heads in ViT (i.e., 6 for ViT-S, 12 for ViT-B, and 16 for ViT-L) as G.

## https://theaisummer.com/static/156f6be8c232bb0b03d4793b2be2fbad/57dc1/vit-models-description-table.png


In [9]:
ggem = GGeM(groups=16, eps=1e-6)

In [10]:
model_ckpt = "google/vit-large-patch16-224"
processor = AutoImageProcessor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

Some weights of the model checkpoint at google/vit-large-patch16-224 were not used when initializing ViTModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
transformation_chain = torchvision.transforms.Compose(
    [
        # We first resize the input image to 256x256, and then we take center crop.
        torchvision.transforms.Resize(int((256 / 224) * processor.size["height"])),
        torchvision.transforms.CenterCrop(processor.size["height"]),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ]
)

In [13]:
image = PIL.Image.open("/home/augustinas/google-landmark/train/0/0/0/000ab4c3e0183bfc.jpg")

In [14]:
img = transformation_chain(image)

In [18]:
img = img.unsqueeze(0)

In [20]:
out = model(img)

In [21]:
out

BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 0.9761,  0.6478, -0.0977,  ..., -0.3124, -0.3076,  1.7004],
         [ 0.6479, -0.5387,  0.8248,  ...,  1.9311, -0.4832,  0.3432],
         [ 0.0613,  0.6355,  1.0129,  ...,  0.7574,  0.2355,  0.1633],
         ...,
         [ 0.0604,  0.9856, -1.5708,  ..., -0.8957, -0.0799,  1.5861],
         [-0.3113,  1.2828, -1.1609,  ..., -1.0621,  0.1105,  1.3995],
         [-0.0850,  1.1988, -1.2745,  ..., -1.6373, -0.1932,  1.5015]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[ 0.2543, -0.1908,  0.1511,  ..., -0.3221, -0.4381,  0.1200]],
       grad_fn=<TanhBackward0>), hidden_states=None, attentions=None)

In [22]:
out = out['pooler_output']

In [23]:
out.shape

torch.Size([1, 1024])

In [24]:
model.pooler

ViTPooler(
  (dense): Linear(in_features=1024, out_features=1024, bias=True)
  (activation): Tanh()
)

In [25]:
model.pooler = ggem

In [26]:
model.pooler

GGeM()

In [27]:
out = model(img)

In [28]:
out

BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 0.9761,  0.6478, -0.0977,  ..., -0.3124, -0.3076,  1.7004],
         [ 0.6479, -0.5387,  0.8248,  ...,  1.9311, -0.4832,  0.3432],
         [ 0.0613,  0.6355,  1.0129,  ...,  0.7574,  0.2355,  0.1633],
         ...,
         [ 0.0604,  0.9856, -1.5708,  ..., -0.8957, -0.0799,  1.5861],
         [-0.3113,  1.2828, -1.1609,  ..., -1.0621,  0.1105,  1.3995],
         [-0.0850,  1.1988, -1.2745,  ..., -1.6373, -0.1932,  1.5015]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[1.0985, 1.3772, 1.3914,  ..., 1.7593, 0.9243, 1.8296]],
       grad_fn=<ReshapeAliasBackward0>), hidden_states=None, attentions=None)

In [30]:
out = out['pooler_output']

In [31]:
out = out.shape

In [32]:
out

torch.Size([1, 1024])