<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Perceiver/Fine_tune_the_Perceiver_for_image_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook, we are going to show how to fine-tune the Perceiver for image classification.

For more information about the Perceiver, see:

* the Transformers docs: https://huggingface.co/docs/transformers/model_doc/perceiver
* the blog post: https://huggingface.co/blog/perceiver

## Set-up environment

We first install HuggingFace Transformers & Datasets.

In [1]:
!pip install -q transformers datasets


## Load data

Here we load a small portion of the CIFAR-10 dataset, for demonstration purposes.

In [2]:
from datasets import load_dataset

# load cifar10 (only small portion for demonstration purposes) 
train_ds, test_ds = load_dataset('cifar10', split=['train[:500]', 'test[:200]'])
# split up training into training + validation
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']


We'll define the id2label and label2id dictionaries, as these will be useful when doing inference.

In [3]:
id2label = {idx:label for idx,label in enumerate(train_ds.features['label'].names)}
label2id = {label:idx for idx, label in id2label.items()}
print(id2label)
print(label2id)

{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
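To make the role of these mappings concrete, here is a minimal sketch of how a predicted class index can be turned back into a readable label at inference time. The logits are made-up numbers for illustration (not model output), and `predict_label` is a hypothetical helper.

```python
# Mapping copied from the printed output above.
id2label = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
            5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
label2id = {label: idx for idx, label in id2label.items()}

def predict_label(logits):
    """Return the label whose logit is largest (hypothetical helper)."""
    best_idx = max(range(len(logits)), key=lambda i: logits[i])
    return id2label[best_idx]

# Made-up logits for a single image; index 3 ('cat') has the highest score.
fake_logits = [0.1, -1.2, 0.3, 2.5, 0.0, 1.9, -0.4, 0.2, -2.0, 0.7]
print(predict_label(fake_logits))
```

Going the other way, `label2id` turns a label name into the integer target that the model is trained on.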


We can prepare the data for the model using the feature extractor.

Note that this feature extractor is fairly basic: it will just do center cropping + resizing + normalizing of the color channels.
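As a rough illustration of the normalization step, the sketch below rescales an 8-bit channel value to [0, 1], subtracts a per-channel mean and divides by a per-channel standard deviation. The ImageNet statistics used here are an assumption about the extractor's defaults, and `normalize` is a hypothetical helper, not part of Transformers.

```python
# Assumed per-channel statistics (the commonly used ImageNet values);
# inspect feature_extractor.image_mean / image_std for the real ones.
IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STD = [0.229, 0.224, 0.225]

def normalize(value, channel):
    """Normalize one 8-bit pixel value for the given channel (0=R, 1=G, 2=B)."""
    scaled = value / 255.0  # rescale from [0, 255] to [0, 1]
    return (scaled - IMAGE_MEAN[channel]) / IMAGE_STD[channel]

red = normalize(113, channel=0)
print(round(red, 4))
```

Under these assumed statistics, a red value of 113 lands near -0.18, which is the same range as the `pixel_values` printed further down in this notebook.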

One should actually add several data augmentations (available in libraries like [torchvision](https://pytorch.org/vision/stable/transforms.html) and [albumentations](https://albumentations.ai/)) to achieve better results. I refer to my [ViT notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_the_%F0%9F%A4%97_Trainer.ipynb) for an example.

In [4]:
from transformers import PerceiverFeatureExtractor

feature_extractor = PerceiverFeatureExtractor()

Note that HuggingFace Datasets has an Image feature, meaning that every image is a PIL (Pillow) image by default. The feature extractor will turn each Pillow image into a PyTorch tensor of shape (3, 224, 224).

Note that Apache Arrow (which HuggingFace Datasets uses as a back-end) can't store PyTorch tensors, but we can work around this with the `set_transform` method on the Dataset, which prepares images only when we need them (i.e. on-the-fly). This is awesome as it saves memory! Refer to the [docs](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.set_transform) for more information.
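The idea of an on-the-fly transform can be sketched with a toy class. This is only an illustration of the concept, not how Datasets implements `set_transform`; `LazyDataset` and `fake_preprocess` are hypothetical names.

```python
class LazyDataset:
    """Toy stand-in for a dataset with an on-the-fly transform (hypothetical)."""

    def __init__(self, rows, transform):
        self.rows = rows
        self.transform = transform
        self.calls = 0  # how many times the transform has run

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        # The work happens here, at access time, not when the dataset is built.
        self.calls += 1
        return self.transform(self.rows[idx])

def fake_preprocess(row):
    # Stand-in for the feature extractor: adds a derived field to the row.
    return {**row, "pixel_values": [v / 255.0 for v in row["img"]]}

ds = LazyDataset([{"img": [0, 255]}, {"img": [51, 102]}], fake_preprocess)
print(ds.calls)  # nothing has been processed yet
first = ds[0]
print(ds.calls, first["pixel_values"])
```

Only the rows that are actually accessed get processed, so the stored data stays small and the tensors exist only while a batch is in use.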

In [10]:
def preprocess_images(examples):
    # `examples` is a batch (a dict of lists): turn its PIL images into a tensor of pixel values
    examples['pixel_values'] = feature_extractor(examples['img'], return_tensors="pt").pixel_values
    return examples

In [11]:
# Set the transforms
train_ds.set_transform(preprocess_images)
val_ds.set_transform(preprocess_images)
test_ds.set_transform(preprocess_images)

We can now load preprocessed images (on-the-fly) as follows:

In [12]:
train_ds[:2]

{'img': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x7FDC654E27D0>,
  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x7FDC654E2110>],
 'label': [7, 5],
 'pixel_values': tensor([[[[-0.1828, -0.1828, -0.1999,  ..., -0.5253, -0.5253, -0.5253],
           [-0.1828, -0.1828, -0.1999,  ..., -0.5253, -0.5253, -0.5253],
           [-0.1999, -0.1999, -0.2171,  ..., -0.5253, -0.5253, -0.5253],
           ...,
           [ 0.5022,  0.5022,  0.5022,  ...,  0.0912,  0.0741,  0.0741],
           [ 0.5022,  0.5022,  0.5022,  ...,  0.1254,  0.1083,  0.1083],
           [ 0.5022,  0.5022,  0.5022,  ...,  0.1426,  0.1254,  0.1254]],
 
          [[-0.1800, -0.1800, -0.2150,  ..., -0.6176, -0.6176, -0.6176],
           [-0.1975, -0.1975, -0.2150,  ..., -0.6176, -0.6176, -0.6176],
           [-0.1975, -0.1975, -0.2150,  ..., -0.6176, -0.6176, -0.6176],
           ...,
           [ 0.5903,  0.5903,  0.5903,  ...,  0.2227,  0.2227,  0.2227],
           [ 0.5903,  0.5903,  

It's very easy to create corresponding PyTorch DataLoaders, like so:

In [17]:
from torch.utils.data import DataLoader
import torch

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

train_batch_size = 2
eval_batch_size = 2

train_dataloader = DataLoader(train_ds, shuffle=True, collate_fn=collate_fn, batch_size=train_batch_size)
val_dataloader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=eval_batch_size)
test_dataloader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=eval_batch_size)
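A quick sanity check on how many batches each loader should yield, written as a small sketch. It assumes the split sizes used in this notebook (500 training images with 10% held out, 200 test images) and the DataLoader default of keeping the last partial batch; `num_batches` is a hypothetical helper.

```python
import math

def num_batches(num_examples, batch_size):
    # With drop_last=False (the DataLoader default), a partial final batch is kept.
    return math.ceil(num_examples / batch_size)

val_size = 500 // 10          # 10% of the 500 loaded training images
train_size = 500 - val_size   # what remains for training
test_size = 200

for name, size in [("train", train_size), ("val", val_size), ("test", test_size)]:
    print(name, num_batches(size, batch_size=2))
```

These counts are what the progress bars should show as their totals when iterating over each loader.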

We can verify our data a bit:

In [18]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k, v.shape)

pixel_values torch.Size([2, 3, 224, 224])
labels torch.Size([2])


Some more verification:

In [20]:
assert batch['pixel_values'].shape == (train_batch_size, 3, 224, 224)
assert batch['labels'].shape == (train_batch_size,)

In [21]:
next(iter(val_dataloader))['pixel_values'].shape

torch.Size([2, 3, 224, 224])

## Define model

Here we only replace the final projection layer of the decoder (`PerceiverClassificationDecoder`) of the checkpoint that was trained on ImageNet. We keep the same (learned) output queries as before, so the cross-attention operation gives the same output. Only the final projection layer changes: it had 1000 output neurons during pre-training (one per ImageNet class), while we only need 10.

NOTE: the Perceiver has 3 variants for image classification:
* PerceiverForImageClassificationLearned
* PerceiverForImageClassificationFourier
* PerceiverForImageClassificationConvProcessing

Here I'm using the first one, which adds learned 1D position embeddings to the pixel values. Note that the best results will be obtained with the last one (`PerceiverForImageClassificationConvProcessing`).

For an in-depth understanding of how the Perceiver works, I refer to my [blog post](https://huggingface.co/blog/perceiver).

We can use the handy `ignore_mismatched_sizes` to replace the head. We also set the `id2label` and `label2id` mappings we defined earlier (which will be handy when doing inference).
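Conceptually, loading with mismatched sizes boils down to comparing parameter shapes and re-initializing the ones that differ. The sketch below illustrates that idea only; it is not the Transformers implementation. `split_by_shape` is a hypothetical helper and `encoder_block.weight` is a placeholder name for any layer that is left unchanged.

```python
def split_by_shape(checkpoint_shapes, model_shapes):
    """Separate parameters that can be loaded from those that must be re-initialized."""
    loadable, reinitialized = [], []
    for name, shape in model_shapes.items():
        if checkpoint_shapes.get(name) == shape:
            loadable.append(name)
        else:
            reinitialized.append(name)
    return loadable, reinitialized

# Checkpoint head: 1000 ImageNet classes. Our head: 10 CIFAR-10 classes.
checkpoint_shapes = {
    "encoder_block.weight": (512, 1024),  # placeholder for an unchanged layer
    "final_layer.weight": (1000, 1024),
    "final_layer.bias": (1000,),
}
model_shapes = {
    "encoder_block.weight": (512, 1024),
    "final_layer.weight": (10, 1024),
    "final_layer.bias": (10,),
}

loadable, reinitialized = split_by_shape(checkpoint_shapes, model_shapes)
print("loaded from checkpoint:", loadable)
print("newly initialized:", reinitialized)

# Size of the new head: a 10 x 1024 weight matrix plus 10 biases.
new_head_params = 10 * 1024 + 10
print(new_head_params)
```

Only the small new head starts from random values, which is why the model still has to be fine-tuned before its predictions are meaningful.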

In [22]:
from transformers import PerceiverForImageClassificationLearned 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned",
                                                               num_labels=10,
                                                               id2label=id2label,
                                                               label2id=label2id,
                                                               ignore_mismatched_sizes=True)
model.to(device)


Some weights of PerceiverForImageClassificationLearned were not initialized from the model checkpoint at deepmind/vision-perceiver-learned and are newly initialized because the shapes did not match:
- perceiver.decoder.decoder.final_layer.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([10, 1024]) in the model instantiated
- perceiver.decoder.decoder.final_layer.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


PerceiverForImageClassificationLearned(
  (perceiver): PerceiverModel(
    (input_preprocessor): PerceiverImagePreprocessor(
      (convnet_1x1): Conv2d(3, 256, kernel_size=(1, 1), stride=(1, 1))
      (position_embeddings): PerceiverTrainablePositionEncoding()
      (positions_projection): Linear(in_features=256, out_features=256, bias=True)
      (conv_after_patches): Identity()
    )
    (embeddings): PerceiverEmbeddings()
    (encoder): PerceiverEncoder(
      (cross_attention): PerceiverLayer(
        (attention): PerceiverAttention(
          (self): PerceiverSelfAttention(
            (layernorm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (query): Linear(in_features=1024, out_features=512, bias=True)
            (key): Linear(in_features=512, out_features=512, bias=True)
            (value): Linear(in_features=512, out_features=512, bias=True)
            (dropout): Dropout(

## Train the model

Here we train the model using native PyTorch.
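A useful reference point when reading the loss values: a classifier that spreads its probability evenly over 10 classes has a cross-entropy of ln(10). The sketch below works this out in plain Python; `cross_entropy` is a hypothetical helper written for illustration, not the loss function the model calls internally.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one example: -log(softmax(logits)[target])."""
    top = max(logits)
    shifted = [x - top for x in logits]  # subtract the max for numerical stability
    log_sum_exp = math.log(sum(math.exp(x) for x in shifted))
    return -(shifted[target] - log_sum_exp)

num_classes = 10
uniform_logits = [0.0] * num_classes       # every class equally likely
baseline = cross_entropy(uniform_logits, target=3)
print(round(baseline, 4))                  # equals ln(10)

confident_logits = [0.0] * num_classes
confident_logits[3] = 5.0                  # strongly favours the correct class
print(cross_entropy(confident_logits, target=3) < baseline)
```

Losses near 2.3 at the start of training are therefore what one would expect from a freshly initialized head, and values well below that indicate the model has started to learn.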

In [25]:
from torch.optim import AdamW  # transformers.AdamW is deprecated in favour of the PyTorch optimizer
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# weight_decay=0.0 matches the default of the former transformers.AdamW
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.0)

model.train()
for epoch in range(2):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    for batch in tqdm(train_dataloader):
        # get the inputs and move them to the device
        inputs = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs=inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # accuracy on the current training batch (noisy with a batch size of 2)
        predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
        accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
        print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

Epoch: 0


  0%|          | 0/225 [00:00<?, ?it/s]

Loss: 2.6121115684509277, Accuracy: 0.0
Loss: 2.087411642074585, Accuracy: 0.5
Loss: 2.516732692718506, Accuracy: 0.0
Loss: 2.915447950363159, Accuracy: 0.0
Loss: 2.631737232208252, Accuracy: 0.0
Loss: 2.2259626388549805, Accuracy: 0.5
Loss: 2.619555950164795, Accuracy: 0.0
Loss: 2.2871932983398438, Accuracy: 0.0
Loss: 2.1763486862182617, Accuracy: 0.0
Loss: 2.0308685302734375, Accuracy: 0.5
Loss: 2.330275297164917, Accuracy: 0.0
Loss: 1.7330862283706665, Accuracy: 1.0
Loss: 2.094666004180908, Accuracy: 0.5
Loss: 2.516792058944702, Accuracy: 0.0
Loss: 1.8952770233154297, Accuracy: 0.5
Loss: 2.5204763412475586, Accuracy: 0.0
Loss: 1.9076743125915527, Accuracy: 0.5
Loss: 1.8706510066986084, Accuracy: 0.5
Loss: 1.2029504776000977, Accuracy: 1.0
Loss: 2.11360764503479, Accuracy: 0.0
Loss: 1.1523082256317139, Accuracy: 0.5
Loss: 2.8178200721740723, Accuracy: 0.0
Loss: 2.510197401046753, Accuracy: 0.0
Loss: 2.008117198944092, Accuracy: 0.0
Loss: 2.6826279163360596, Accuracy: 0.0
Loss: 1.6611

  0%|          | 0/225 [00:00<?, ?it/s]

Loss: 0.2511093020439148, Accuracy: 1.0
Loss: 0.20040327310562134, Accuracy: 1.0
Loss: 0.056868888437747955, Accuracy: 1.0
Loss: 0.0975453332066536, Accuracy: 1.0
Loss: 1.1249535083770752, Accuracy: 0.5
Loss: 0.04263799637556076, Accuracy: 1.0
Loss: 0.3301231861114502, Accuracy: 1.0
Loss: 2.047173023223877, Accuracy: 0.0
Loss: 0.3231906592845917, Accuracy: 1.0
Loss: 0.4727799892425537, Accuracy: 1.0
Loss: 0.06914356350898743, Accuracy: 1.0
Loss: 0.20520082116127014, Accuracy: 1.0
Loss: 0.0944804772734642, Accuracy: 1.0
Loss: 0.1908605694770813, Accuracy: 1.0
Loss: 0.06819569319486618, Accuracy: 1.0
Loss: 0.058507561683654785, Accuracy: 1.0
Loss: 0.030860979110002518, Accuracy: 1.0
Loss: 0.04331466555595398, Accuracy: 1.0
Loss: 1.1817548274993896, Accuracy: 0.5
Loss: 0.07851042598485947, Accuracy: 1.0
Loss: 0.01915682852268219, Accuracy: 1.0
Loss: 0.07248327136039734, Accuracy: 1.0
Loss: 0.1211576983332634, Accuracy: 1.0
Loss: 0.10418497025966644, Accuracy: 1.0
Loss: 0.04465044289827347

## Evaluate the model

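Finally, we evaluate the model. We use the Datasets library to compute the accuracy. Note that the loop below iterates over `val_dataloader`, so the number it prints (labelled "test set") is really the accuracy on the held-out validation images; pass `test_dataloader` instead to score the test images.

On some runs, I got 78%, then 66%, so expect the score to vary from run to run with this little data. Of course, one would need to train on the entire dataset to achieve great results.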
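The sketch below illustrates two things in plain Python: how a batch-wise accuracy metric can accumulate predictions before computing a final score, and why scores on a small evaluation set fluctuate. `RunningAccuracy` is a hypothetical toy class (not the Datasets metric), the predictions are made up, and the standard-error estimate assumes n = 50 evaluated images (25 batches of 2) and a simple binomial model.

```python
import math

class RunningAccuracy:
    """Toy accumulator (hypothetical): collect batches, then compute accuracy."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def add_batch(self, predictions, references):
        self.correct += sum(p == r for p, r in zip(predictions, references))
        self.total += len(references)

    def compute(self):
        return {"accuracy": self.correct / self.total}

metric = RunningAccuracy()
metric.add_batch(predictions=[7, 5], references=[7, 3])  # one hit, one miss
metric.add_batch(predictions=[1, 1], references=[1, 1])  # two hits
print(metric.compute())

def standard_error(accuracy, n):
    """Binomial standard error of an accuracy measured on n examples."""
    return math.sqrt(accuracy * (1 - accuracy) / n)

print(round(standard_error(0.74, 50), 3))
```

With a standard error of roughly 6 percentage points, a spread between 66% and 78% across runs is not surprising for an evaluation set this small.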

In [27]:
from tqdm.notebook import tqdm
from datasets import load_metric  # newer releases move this to the separate `evaluate` library (evaluate.load)

accuracy = load_metric("accuracy")

model.eval()
# NOTE: this loop scores the validation split; pass test_dataloader to score the test split
with torch.no_grad():  # gradients are not needed for evaluation
    for batch in tqdm(val_dataloader):
        # get the inputs and move them to the device
        inputs = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # forward pass
        outputs = model(inputs=inputs, labels=labels)
        logits = outputs.logits
        predictions = logits.argmax(-1).cpu().numpy()
        references = batch["labels"].numpy()
        accuracy.add_batch(predictions=predictions, references=references)

final_score = accuracy.compute()
print("Accuracy on test set:", final_score)

  0%|          | 0/25 [00:00<?, ?it/s]

Accuracy on test set: {'accuracy': 0.74}
