<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Perceiver/Fine_tune_the_Perceiver_for_image_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook, we are going to show how to fine-tune the Perceiver for image classification.

For more info regarding the Perceiver, I refer to:

* the Transformers docs: https://huggingface.co/docs/transformers/model_doc/perceiver
* the blog post: https://huggingface.co/blog/perceiver

## Set-up environment

We first install Hugging Face Transformers and Datasets.

In [1]:
!pip install -q transformers datasets


## Load data

Here we load a small portion of the CIFAR-10 dataset, for demonstration purposes.

In [2]:
from datasets import load_dataset

# load cifar10 (only small portion for demonstration purposes) 
train_ds, test_ds = load_dataset('cifar10', split=['train[:500]', 'test[:200]'])
# split up training into training + validation
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']


We'll define the id2label and label2id dictionaries, as these will be useful when doing inference.

In [3]:
id2label = {idx:label for idx,label in enumerate(train_ds.features['label'].names)}
label2id = {label:idx for idx, label in id2label.items()}
print(id2label)
print(label2id)

{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
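As a small illustration of why these mappings help at inference time, the sketch below turns a vector of class scores into a readable label. The `logits` values are made up for the example, and `predict_label` is a hypothetical helper, not part of Transformers.

```python
# the mapping printed above
id2label = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
            5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
label2id = {label: idx for idx, label in id2label.items()}

def predict_label(logits, id2label):
    """Return the label whose score is highest (hypothetical helper)."""
    best_idx = max(range(len(logits)), key=lambda i: logits[i])
    return id2label[best_idx]

# made-up scores for a single image, one per class
logits = [0.1, -1.2, 0.3, 2.5, 0.0, 1.9, -0.4, 0.2, -2.0, 0.7]
predicted = predict_label(logits, id2label)
print(predicted)
```

Passing both dictionaries to the model config later means the model can report class names rather than bare indices.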


We can prepare the data for the model using the feature extractor.

Note that this feature extractor is fairly basic: it will just do center cropping + resizing + normalizing of the color channels.

One should actually add several data augmentations (available in libraries like [torchvision](https://pytorch.org/vision/stable/transforms.html) and [albumentations](https://albumentations.ai/)) to achieve better results.
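To give an idea of what such augmentations do, here is a minimal NumPy-only sketch of a random horizontal flip followed by a random crop from a zero-padded image. It is a stand-in written for this example (the helper name and the padding of 4 pixels are my assumptions); in practice one would reach for the equivalent transforms in torchvision or albumentations, and this notebook does not apply any augmentation.

```python
import numpy as np

def random_flip_and_crop(image, rng, padding=4):
    """Randomly flip an (H, W, C) image left-right, then take a random crop
    of the original size from a zero-padded copy (hypothetical helper)."""
    height, width, _ = image.shape
    if rng.random() < 0.5:
        image = image[:, ::-1, :]  # horizontal flip
    padded = np.pad(image, ((padding, padding), (padding, padding), (0, 0)), mode="constant")
    # pick the top-left corner of the crop inside the padded image
    top = int(rng.integers(0, 2 * padding + 1))
    left = int(rng.integers(0, 2 * padding + 1))
    return padded[top:top + height, left:left + width, :]

rng = np.random.default_rng(0)
# random pixels as a stand-in for one 32x32 RGB CIFAR-10 image
dummy = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
augmented = random_flip_and_crop(dummy, rng)
print(augmented.shape, augmented.dtype)
```

Augmentations like these would be applied to the training images only, before they are handed to the feature extractor.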

In [4]:
from transformers import PerceiverFeatureExtractor

feature_extractor = PerceiverFeatureExtractor()

In [5]:
import numpy as np

def preprocess_images(examples):
    # get batch of images
    images = examples['img']
    # convert to list of NumPy arrays of shape (C, H, W)
    images = [np.array(image, dtype=np.uint8) for image in images]
    images = [np.moveaxis(image, source=-1, destination=0) for image in images]
    # preprocess and add pixel_values
    inputs = feature_extractor(images=images)
    examples['pixel_values'] = inputs['pixel_values']

    return examples
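To make the layout and normalization steps concrete, here is a NumPy-only sketch of roughly what happens to one image, leaving out the cropping and resizing. The mean and standard deviation below are the usual ImageNet statistics, which I am assuming here rather than reading from the feature extractor, so treat the numbers as illustrative.

```python
import numpy as np

# assumed ImageNet statistics, not read from the feature extractor
IMAGE_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGE_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def to_normalized_chw(image_hwc):
    """Scale to [0, 1], normalize each color channel, and move channels first."""
    scaled = image_hwc.astype(np.float32) / 255.0
    normalized = (scaled - IMAGE_MEAN) / IMAGE_STD  # broadcasts over the last (channel) axis
    return np.moveaxis(normalized, source=-1, destination=0)

dummy = np.full((32, 32, 3), 255, dtype=np.uint8)  # an all-white stand-in image
pixel_values = to_normalized_chw(dummy)
print(pixel_values.shape)
```

The real feature extractor additionally crops and resizes, which is why the `pixel_values` in this notebook end up with shape (3, 224, 224).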

In [6]:
from datasets import Features, ClassLabel, Array3D

# we need to define the features ourselves as both the img and pixel_values have a 3D shape 
features = Features({
    'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']),
    'img': Array3D(dtype="int64", shape=(3,32,32)),
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
})

preprocessed_train_ds = train_ds.map(preprocess_images, batched=True, features=features)
preprocessed_val_ds = val_ds.map(preprocess_images, batched=True, features=features)
preprocessed_test_ds = test_ds.map(preprocess_images, batched=True, features=features)


Finally, we turn everything into PyTorch tensors.

In [7]:
# set format to PyTorch
preprocessed_train_ds.set_format('torch', columns=['pixel_values', 'label'])
preprocessed_val_ds.set_format('torch', columns=['pixel_values', 'label'])
preprocessed_test_ds.set_format('torch', columns=['pixel_values', 'label'])

We can verify our data a bit:

In [8]:
preprocessed_train_ds[0].keys()

dict_keys(['label', 'pixel_values'])

Next, we create corresponding PyTorch dataloaders.

In [9]:
import torch

# create dataloaders
train_batch_size = 2
eval_batch_size = 2
train_dataloader = torch.utils.data.DataLoader(preprocessed_train_ds, batch_size=train_batch_size, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(preprocessed_val_ds, batch_size=eval_batch_size)
test_dataloader = torch.utils.data.DataLoader(preprocessed_test_ds, batch_size=eval_batch_size)
batch = next(iter(train_dataloader))
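As a quick sanity check on the sizes involved, the arithmetic below works out how many batches each dataloader should yield per epoch, given the slices and the 10% validation split used earlier. `num_batches` is a hypothetical helper; it assumes the last partial batch is kept, which is what `DataLoader` does by default.

```python
import math

def num_batches(num_examples, batch_size):
    """Batches per epoch when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)

num_loaded = 500                    # 'train[:500]'
num_val = num_loaded * 10 // 100    # test_size=0.1
num_train = num_loaded - num_val
num_test = 200                      # 'test[:200]'

for name, n in [("train", num_train), ("val", num_val), ("test", num_test)]:
    print(name, n, num_batches(n, batch_size=2))
```

With a batch size of 2 this gives 225 training batches, 25 validation batches and 100 test batches.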

Some more verification:

In [10]:
assert batch['pixel_values'].shape == (train_batch_size, 3, 224, 224)
assert batch['label'].shape == (train_batch_size,)

In [11]:
next(iter(val_dataloader))['pixel_values'].shape

torch.Size([2, 3, 224, 224])

## Define model

Here we only replace the final projection layer of the decoder (`PerceiverClassificationDecoder`) of the checkpoint that was trained on ImageNet. We keep the same (learned) output queries as before, so the cross-attention operation gives the same output. The final projection layer, however, has 1000 output neurons in the pre-trained checkpoint, while we only need 10 (one per CIFAR-10 class).
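To make the head swap concrete, here is a minimal PyTorch sketch that compares a 1000-class projection with a 10-class one. It uses plain `nn.Linear` layers as stand-ins and is not the actual Perceiver code; the hidden size of 1024 is taken from the weight shapes reported when the checkpoint is loaded.

```python
import torch
from torch import nn

hidden_size = 1024                              # input width of the final projection layer
pretrained_head = nn.Linear(hidden_size, 1000)  # ImageNet: 1000 classes
new_head = nn.Linear(hidden_size, 10)           # CIFAR-10: 10 classes, randomly initialized

def count_parameters(module):
    """Total number of parameters (weights and biases) in a module."""
    return sum(p.numel() for p in module.parameters())

# stand-in for the decoder output of a batch of 2 images
decoder_output = torch.randn(2, hidden_size)
logits = new_head(decoder_output)
print(logits.shape, count_parameters(pretrained_head), count_parameters(new_head))
```

Only this small layer starts from random weights; everything before it is reused from the pre-trained checkpoint.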

NOTE: the Perceiver has 3 variants for image classification:
* PerceiverForImageClassificationLearned
* PerceiverForImageClassificationFourier
* PerceiverForImageClassificationConvProcessing

Here I'm using the first one, which adds learned 1D position embeddings to the pixel values. Note that the best results will be obtained with the last one, `PerceiverForImageClassificationConvProcessing`.

For an in-depth understanding of how the Perceiver works, I refer to my [blog post](https://huggingface.co/blog/perceiver).

We can use the handy `ignore_mismatched_sizes` to replace the head. We also set the `id2label` and `label2id` mappings we defined earlier (which will be handy when doing inference).

In [12]:
from transformers import PerceiverForImageClassificationLearned 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned",
                                                               num_labels=10,
                                                               id2label=id2label,
                                                               label2id=label2id,
                                                               ignore_mismatched_sizes=True)
model.to(device)


Some weights of PerceiverForImageClassificationLearned were not initialized from the model checkpoint at deepmind/vision-perceiver-learned and are newly initialized because the shapes did not match:
- perceiver.decoder.decoder.final_layer.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([10, 1024]) in the model instantiated
- perceiver.decoder.decoder.final_layer.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


PerceiverForImageClassificationLearned(
  (perceiver): PerceiverModel(
    (input_preprocessor): PerceiverImagePreprocessor(
      (convnet_1x1): Conv2d(3, 256, kernel_size=(1, 1), stride=(1, 1))
      (position_embeddings): PerceiverTrainablePositionEncoding()
      (positions_projection): Linear(in_features=256, out_features=256, bias=True)
      (conv_after_patches): Identity()
    )
    (embeddings): PerceiverEmbeddings()
    (encoder): PerceiverEncoder(
      (cross_attention): PerceiverLayer(
        (attention): PerceiverAttention(
          (self): PerceiverSelfAttention(
            (layernorm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (query): Linear(in_features=1024, out_features=512, bias=True)
            (key): Linear(in_features=512, out_features=512, bias=True)
            (value): Linear(in_features=512, out_features=512, bias=True)
            (dropout): Dropout(

## Train the model

Here we train the model using native PyTorch.

In [13]:
from torch.optim import AdamW  # transformers.AdamW is deprecated in favor of the PyTorch implementation
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

optimizer = AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(2):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    for batch in tqdm(train_dataloader):
        # get the inputs
        inputs = batch["pixel_values"].to(device)
        labels = batch["label"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs=inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # per-batch accuracy, based on the predictions made before this optimizer step
        predictions = outputs.logits.argmax(-1).detach().cpu().numpy()
        accuracy = accuracy_score(y_true=batch["label"].numpy(), y_pred=predictions)
        print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

Epoch: 0


  0%|          | 0/225 [00:00<?, ?it/s]

Loss: 2.449394464492798, Accuracy: 0.5
Loss: 2.676224708557129, Accuracy: 0.0
Loss: 2.3117918968200684, Accuracy: 0.0
Loss: 2.7640700340270996, Accuracy: 0.0
Loss: 2.153238296508789, Accuracy: 0.5
Loss: 1.7548975944519043, Accuracy: 0.5
Loss: 2.592646598815918, Accuracy: 0.0
Loss: 2.4706153869628906, Accuracy: 0.0
Loss: 2.0476417541503906, Accuracy: 0.0
Loss: 2.4134531021118164, Accuracy: 0.0
Loss: 2.3316354751586914, Accuracy: 0.0
Loss: 1.7927266359329224, Accuracy: 1.0
Loss: 2.0628437995910645, Accuracy: 0.5
Loss: 1.8099902868270874, Accuracy: 0.0
Loss: 1.8924148082733154, Accuracy: 0.5
Loss: 2.362542152404785, Accuracy: 0.0
Loss: 1.4972447156906128, Accuracy: 0.5
Loss: 2.307976484298706, Accuracy: 0.0
Loss: 2.3332736492156982, Accuracy: 0.0
Loss: 3.058933734893799, Accuracy: 0.0
Loss: 2.3993752002716064, Accuracy: 0.0
Loss: 2.2155649662017822, Accuracy: 0.5
Loss: 2.199784755706787, Accuracy: 0.5
Loss: 2.624464988708496, Accuracy: 0.0
Loss: 1.8652068376541138, Accuracy: 0.5
Loss: 1.9

  0%|          | 0/225 [00:00<?, ?it/s]

Loss: 0.024263061583042145, Accuracy: 1.0
Loss: 0.5569589734077454, Accuracy: 1.0
Loss: 0.011131217703223228, Accuracy: 1.0
Loss: 0.14349067211151123, Accuracy: 1.0
Loss: 1.9892107248306274, Accuracy: 0.0
Loss: 0.25663527846336365, Accuracy: 1.0
Loss: 0.05386708304286003, Accuracy: 1.0
Loss: 0.42413264513015747, Accuracy: 1.0
Loss: 0.5003420114517212, Accuracy: 1.0
Loss: 0.5049768686294556, Accuracy: 1.0
Loss: 0.48787593841552734, Accuracy: 1.0
Loss: 0.11221908032894135, Accuracy: 1.0
Loss: 0.31401270627975464, Accuracy: 1.0
Loss: 0.07242441177368164, Accuracy: 1.0
Loss: 0.4570055603981018, Accuracy: 1.0
Loss: 5.9248948097229, Accuracy: 0.0
Loss: 0.034650105983018875, Accuracy: 1.0
Loss: 3.429060459136963, Accuracy: 0.5
Loss: 1.5643033981323242, Accuracy: 0.5
Loss: 0.6744573712348938, Accuracy: 0.5
Loss: 0.007175706792622805, Accuracy: 1.0
Loss: 0.6728950142860413, Accuracy: 1.0
Loss: 2.1263370513916016, Accuracy: 0.5
Loss: 1.5938011407852173, Accuracy: 0.5
Loss: 0.9420684576034546, Ac
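With a batch size of 2, the per-batch accuracy can only be 0.0, 0.5 or 1.0, so the log above is quite noisy. One option is to track a running average per epoch instead. The sketch below uses a small hypothetical `RunningAverage` helper and feeds it the accuracies of the first five steps of the first epoch, copied from the log.

```python
class RunningAverage:
    """Keeps a running mean of a stream of numbers (hypothetical helper)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    @property
    def mean(self):
        # avoid dividing by zero before the first update
        return self.total / self.count if self.count else 0.0

# per-batch accuracies of the first five steps of the first epoch
step_accuracies = [0.5, 0.0, 0.0, 0.0, 0.5]

tracker = RunningAverage()
for acc in step_accuracies:
    tracker.update(acc)
print(f"Mean accuracy over {tracker.count} steps: {tracker.mean:.2f}")
```

In the training loop, one would create a fresh tracker at the start of each epoch, call `update` once per batch, and print the mean once at the end of the epoch.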

## Evaluate the model

Finally, we evaluate the model on the held-out validation split (pass `test_dataloader` to the loop instead to score the 200 test images). We use the Datasets library to compute the accuracy.

Across runs I got 78% once and 66% another time, so expect sizeable variance with so little data. Of course, one would need to train on the entire dataset to achieve great results.

In [14]:
import torch
from tqdm.notebook import tqdm
from datasets import load_metric  # in recent releases, metrics live in the separate `evaluate` library instead

accuracy = load_metric("accuracy")

model.eval()
with torch.no_grad():  # no gradients needed for evaluation
    for batch in tqdm(val_dataloader):
        # get the inputs
        inputs = batch["pixel_values"].to(device)
        labels = batch["label"].to(device)

        # forward pass
        outputs = model(inputs=inputs, labels=labels)
        logits = outputs.logits
        predictions = logits.argmax(-1).cpu().numpy()
        references = batch["label"].numpy()
        accuracy.add_batch(predictions=predictions, references=references)

final_score = accuracy.compute()
print("Accuracy on validation set:", final_score)

Accuracy on validation set: {'accuracy': 0.66}
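For reference, the accuracy metric is simply the fraction of predictions that match their reference label. The sketch below is a plain-Python version with made-up class indices; the `accuracy` function here is a hypothetical stand-in written for illustration, not the metric object used above.

```python
def accuracy(predictions, references):
    """Fraction of predictions equal to their reference label."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    correct = sum(int(p == r) for p, r in zip(predictions, references))
    return correct / len(references)

# made-up class indices for six images
predictions = [3, 5, 0, 9, 1, 1]
references = [3, 5, 2, 9, 1, 7]
score = accuracy(predictions, references)
print(score)
```

Read this way, a score of 0.66 over 25 batches of 2 images means 33 of the 50 images were classified correctly.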
