<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Perceiver/Fine_tune_the_Perceiver_for_image_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook, we show how to fine-tune the Perceiver for image classification, using a small subset of CIFAR-10.

## Set-up environment

We first install Hugging Face Transformers (from source) and Datasets.

In [None]:
!pip install -q git+https://github.com/huggingface/transformers.git


In [None]:
!pip install -q datasets


## Load data

In [None]:
from datasets import load_dataset

# load CIFAR-10 (only a small portion, for demonstration purposes)
train_ds, test_ds = load_dataset('cifar10', split=['train[:500]', 'test[:200]'])
# split the training portion into training (90%) + validation (10%)
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']

Downloading and preparing dataset cifar10/plain_text (download: 162.60 MiB, generated: 418.17 MiB, post-processed: Unknown size, total: 580.77 MiB) to /root/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/5da9550526dac91579c0df95a56466f78e62cc6ea1ccffd17f71f2e64aa86b5e...



Dataset cifar10 downloaded and prepared to /root/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/5da9550526dac91579c0df95a56466f78e62cc6ea1ccffd17f71f2e64aa86b5e. Subsequent calls will reuse this data.
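The slicing above yields 500 training and 200 test images. A quick back-of-the-envelope check of the resulting split and batch counts, assuming that `train_test_split` rounds the test share up (the scikit-learn convention) and that the dataloaders keep the final partial batch (PyTorch's default `drop_last=False`). The helper names are illustrative only:

```python
import math
from typing import Tuple

def split_sizes(n_samples: int, test_size: float) -> Tuple[int, int]:
    """Return (n_train, n_test), rounding the test share up (assumed convention)."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test

def num_batches(n_examples: int, batch_size: int) -> int:
    """Number of batches when the last, possibly smaller, batch is kept."""
    return math.ceil(n_examples / batch_size)

n_train, n_val = split_sizes(500, 0.1)
n_test = 200
print(n_train, n_val, n_test)  # 450 50 200
print(num_batches(n_train, 2), num_batches(n_val, 2), num_batches(n_test, 2))  # 225 25 100
```

These counts line up with the 225 training steps per epoch and the 25 evaluation batches shown in the progress bars further down.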



In [None]:
id2label = {idx:label for idx,label in enumerate(train_ds.features['label'].names)}
label2id = {label:idx for idx, label in id2label.items()}
print(id2label)
print(label2id)

{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
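At inference time, `id2label` turns the index of the highest logit back into a class name, and `label2id` is simply its inverse. A minimal sketch with made-up logits for a single image (the logit values are illustrative, not outputs of the model):

```python
import numpy as np

names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
         'dog', 'frog', 'horse', 'ship', 'truck']
id2label = {idx: label for idx, label in enumerate(names)}
label2id = {label: idx for idx, label in id2label.items()}

# the two mappings are inverses of each other
assert all(label2id[id2label[i]] == i for i in id2label)

# hypothetical logits for one image, shape (num_labels,)
logits = np.array([0.1, -1.2, 0.3, 2.5, 0.0, 1.9, -0.4, 0.2, -2.0, 0.7])
predicted_idx = int(logits.argmax())
print(predicted_idx, id2label[predicted_idx])  # 3 cat
```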


In [None]:
from transformers import PerceiverFeatureExtractor

feature_extractor = PerceiverFeatureExtractor()

In [None]:
import numpy as np

def preprocess_images(examples):
    # get batch of images
    images = examples['img']
    # convert to list of NumPy arrays of shape (C, H, W)
    images = [np.array(image, dtype=np.uint8) for image in images]
    images = [np.moveaxis(image, source=-1, destination=0) for image in images]
    # preprocess and add pixel_values
    inputs = feature_extractor(images=images)
    examples['pixel_values'] = inputs['pixel_values']

    return examples
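`np.moveaxis(image, source=-1, destination=0)` moves the channel axis from last to first, turning a height × width × channels array (the layout an RGB image has after `np.array(...)`) into channels × height × width. A small check on a dummy image:

```python
import numpy as np

# dummy 32x32 RGB image in (H, W, C) layout with a distinct value per channel
image = np.zeros((32, 32, 3), dtype=np.uint8)
image[..., 0] = 10  # red
image[..., 1] = 20  # green
image[..., 2] = 30  # blue

chw = np.moveaxis(image, source=-1, destination=0)
print(image.shape, "->", chw.shape)  # (32, 32, 3) -> (3, 32, 32)

# each channel is now a full 32x32 plane indexed by the first axis
assert chw[0, 5, 7] == 10 and chw[2, 5, 7] == 30
```

Note that `np.moveaxis` returns a view, so no pixel data is copied at this step.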

In [None]:
from datasets import Features, ClassLabel, Array3D

# define the features explicitly, since both img and pixel_values are 3D arrays
features = Features({
    'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']),
    'img': Array3D(dtype="int64", shape=(3,32,32)),
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
})

preprocessed_train_ds = train_ds.map(preprocess_images, batched=True, features=features)
preprocessed_val_ds = val_ds.map(preprocess_images, batched=True, features=features)
preprocessed_test_ds = test_ds.map(preprocess_images, batched=True, features=features)


In [None]:
# set format to PyTorch
preprocessed_train_ds.set_format('torch', columns=['pixel_values', 'label'])
preprocessed_val_ds.set_format('torch', columns=['pixel_values', 'label'])
preprocessed_test_ds.set_format('torch', columns=['pixel_values', 'label'])
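Declaring `pixel_values` as float32 with shape (3, 224, 224) also makes the preprocessed dataset fairly large, so working with a small slice keeps things manageable. The arithmetic, ignoring any storage overhead (the helper name is illustrative):

```python
import math

def float32_array_bytes(shape) -> int:
    """Bytes needed for a dense float32 array of the given shape."""
    return math.prod(shape) * 4  # 4 bytes per float32 value

per_image = float32_array_bytes((3, 224, 224))
print(per_image)  # 602112 bytes, about 0.57 MiB

for name, n in [("train", 450), ("val", 50), ("test", 200)]:
    print(name, round(n * per_image / 2**20, 1), "MiB")
```

By the same arithmetic, preprocessing the full 50,000-image CIFAR-10 training split this way would need about 28 GiB for `pixel_values` alone.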

In [None]:
preprocessed_train_ds[0].keys()

dict_keys(['label', 'pixel_values'])

In [None]:
import torch

# create dataloaders (only the training data is shuffled)
train_batch_size = 2
eval_batch_size = 2
train_dataloader = torch.utils.data.DataLoader(preprocessed_train_ds, batch_size=train_batch_size, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(preprocessed_val_ds, batch_size=eval_batch_size)
test_dataloader = torch.utils.data.DataLoader(preprocessed_test_ds, batch_size=eval_batch_size)
batch = next(iter(train_dataloader))

In [None]:
assert batch['pixel_values'].shape == (train_batch_size, 3, 224, 224)
assert batch['label'].shape == (train_batch_size,)
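These shape assertions hold because the DataLoader's default collation stacks the per-example `(3, 224, 224)` tensors along a new leading batch axis and gathers the scalar labels into a 1-D tensor. A NumPy analogue of that stacking (a sketch of the idea with a hypothetical `collate_numpy` helper, not PyTorch's own collate function):

```python
import numpy as np

def collate_numpy(examples):
    """Stack a list of {'pixel_values', 'label'} dicts into batched arrays."""
    return {
        "pixel_values": np.stack([ex["pixel_values"] for ex in examples]),
        "label": np.array([ex["label"] for ex in examples]),
    }

# two dummy examples shaped like rows of the preprocessed dataset
dummy_examples = [
    {"pixel_values": np.zeros((3, 224, 224), dtype=np.float32), "label": 3},
    {"pixel_values": np.ones((3, 224, 224), dtype=np.float32), "label": 8},
]
np_batch = collate_numpy(dummy_examples)
print(np_batch["pixel_values"].shape, np_batch["label"].shape)  # (2, 3, 224, 224) (2,)
```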

In [None]:
next(iter(val_dataloader))['pixel_values'].shape

torch.Size([2, 3, 224, 224])

## Define model

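Here we start from a checkpoint trained on ImageNet and replace only its classification head: that head has 1,000 output neurons, whereas CIFAR-10 has only 10 classes.

Passing `ignore_mismatched_sizes=True` to `from_pretrained` lets the checkpoint load even though the head's shapes differ; the mismatched head weights are newly initialized instead (see the warning printed below). We also pass the `id2label` and `label2id` mappings defined earlier, which are stored in the model config and make predictions easier to interpret at inference time.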
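Conceptually, `ignore_mismatched_sizes=True` compares each checkpoint tensor's shape with the matching parameter of the freshly built model and keeps the model's newly initialized tensor wherever they differ. A simplified illustration with a hypothetical `find_mismatched` helper, using the head shapes reported in the loading warning of the next cell (this is not the Transformers implementation):

```python
def find_mismatched(checkpoint_shapes: dict, model_shapes: dict) -> list:
    """Return (name, checkpoint_shape, model_shape) for every shape mismatch."""
    mismatched = []
    for name, model_shape in model_shapes.items():
        ckpt_shape = checkpoint_shapes.get(name)
        if ckpt_shape is not None and ckpt_shape != model_shape:
            mismatched.append((name, ckpt_shape, model_shape))
    return mismatched

head = "perceiver.decoder.decoder.final_layer"
checkpoint_shapes = {f"{head}.weight": (1000, 1024), f"{head}.bias": (1000,)}
model_shapes = {f"{head}.weight": (10, 1024), f"{head}.bias": (10,)}

for name, old, new in find_mismatched(checkpoint_shapes, model_shapes):
    print(f"{name}: checkpoint {old} -> re-initialized {new}")
```

Every other tensor has matching shapes, so it is loaded from the checkpoint unchanged.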

In [None]:
from transformers import PerceiverForImageClassificationLearned 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PerceiverForImageClassificationLearned.from_pretrained("nielsr/vision-perceiver",
                                                               num_labels=10,
                                                               id2label=id2label,
                                                               label2id=label2id,
                                                               ignore_mismatched_sizes=True)
model.to(device)

Some weights of PerceiverForImageClassificationLearned were not initialized from the model checkpoint at nielsr/vision-perceiver and are newly initialized because the shapes did not match:
- perceiver.decoder.decoder.final_layer.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([10, 1024]) in the model instantiated
- perceiver.decoder.decoder.final_layer.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


PerceiverForImageClassificationLearned(
  (perceiver): PerceiverModel(
    (input_preprocessor): PerceiverImagePreprocessor(
      (convnet_1x1): Conv2d(3, 256, kernel_size=(1, 1), stride=(1, 1))
      (position_embeddings): PerceiverTrainablePositionEncoding()
      (positions_projection): Linear(in_features=256, out_features=256, bias=True)
      (conv_after_patches): Identity()
    )
    (embeddings): PerceiverEmbeddings()
    (encoder): PerceiverEncoder(
      (cross_attention): PerceiverLayer(
        (attention): PerceiverAttention(
          (self): PerceiverSelfAttention(
            (layernorm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (query): Linear(in_features=1024, out_features=512, bias=True)
            (key): Linear(in_features=512, out_features=512, bias=True)
            (value): Linear(in_features=512, out_features=512, bias=True)
            (dropout): Dropout(...)
  ...

## Train the model

Here we train the model using native PyTorch.
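The training loop prints the accuracy of every training batch. With `train_batch_size = 2`, that value can only be 0.0, 0.5 or 1.0, so it is a very noisy signal; the loss is more informative from step to step. A plain-Python equivalent of the accuracy computation (a sketch with a hypothetical `batch_accuracy` helper, matching what `sklearn.metrics.accuracy_score` returns for label arrays):

```python
from itertools import product

def batch_accuracy(y_true, y_pred) -> float:
    """Fraction of positions where the predicted class equals the true class."""
    assert len(y_true) == len(y_pred) and len(y_true) > 0
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# enumerate every outcome for a batch of 2 examples over 10 classes
possible = {batch_accuracy(t, p)
            for t in product(range(10), repeat=2)
            for p in product(range(10), repeat=2)}
print(sorted(possible))  # [0.0, 0.5, 1.0]
```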

In [None]:
from torch.optim import AdamW  # transformers.AdamW is deprecated; use PyTorch's AdamW
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

optimizer = AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(2):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    for batch in tqdm(train_dataloader):
        # get the inputs and move them to the device
        inputs = batch["pixel_values"].to(device)
        labels = batch["label"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs=inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # compute accuracy on the current training batch
        predictions = outputs.logits.argmax(-1).detach().cpu().numpy()
        accuracy = accuracy_score(y_true=batch["label"].numpy(), y_pred=predictions)
        print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

Epoch: 0



Loss: 2.233168125152588, Accuracy: 0.0
Loss: 2.203639030456543, Accuracy: 0.5
Loss: 2.240338087081909, Accuracy: 0.0
Loss: 2.353327989578247, Accuracy: 0.0
Loss: 2.403595447540283, Accuracy: 0.0
Loss: 1.7836918830871582, Accuracy: 0.5
Loss: 2.5166165828704834, Accuracy: 0.0
Loss: 1.8852105140686035, Accuracy: 0.5
Loss: 1.9912769794464111, Accuracy: 0.5
Loss: 2.3983027935028076, Accuracy: 0.5
Loss: 2.259859561920166, Accuracy: 0.5
Loss: 1.768882393836975, Accuracy: 0.5
Loss: 1.9442031383514404, Accuracy: 0.5
Loss: 1.3550140857696533, Accuracy: 1.0
Loss: 1.8230805397033691, Accuracy: 0.5
Loss: 3.0120019912719727, Accuracy: 0.0
Loss: 1.2887790203094482, Accuracy: 0.5
Loss: 1.6598515510559082, Accuracy: 0.5
Loss: 2.4769392013549805, Accuracy: 0.0
Loss: 1.92896568775177, Accuracy: 0.5
Loss: 1.7209858894348145, Accuracy: 1.0
Loss: 1.905043363571167, Accuracy: 0.5
Loss: 1.9800982475280762, Accuracy: 0.5
Loss: 2.3855628967285156, Accuracy: 0.0
Loss: 1.1238913536071777, Accuracy: 1.0
Loss: 1.75


Loss: 0.03321206569671631, Accuracy: 1.0
Loss: 0.8680737614631653, Accuracy: 0.5
Loss: 0.21213267743587494, Accuracy: 1.0
Loss: 0.17781774699687958, Accuracy: 1.0
Loss: 1.1009070873260498, Accuracy: 0.5
Loss: 0.023570826277136803, Accuracy: 1.0
Loss: 0.08928795158863068, Accuracy: 1.0
Loss: 1.228294849395752, Accuracy: 0.5
Loss: 0.030441101640462875, Accuracy: 1.0
Loss: 0.11671231687068939, Accuracy: 1.0
Loss: 0.8187150955200195, Accuracy: 0.5
Loss: 0.08674780279397964, Accuracy: 1.0
Loss: 0.09608927369117737, Accuracy: 1.0
Loss: 0.1204434409737587, Accuracy: 1.0
Loss: 0.9842737317085266, Accuracy: 0.5
Loss: 0.7288386821746826, Accuracy: 1.0
Loss: 0.04156428575515747, Accuracy: 1.0
Loss: 0.17316851019859314, Accuracy: 1.0
Loss: 0.27705875039100647, Accuracy: 1.0
Loss: 0.36866387724876404, Accuracy: 1.0
Loss: 0.019706543534994125, Accuracy: 1.0
Loss: 0.3954754173755646, Accuracy: 1.0
Loss: 0.07596493512392044, Accuracy: 1.0
Loss: 0.16080008447170258, Accuracy: 1.0
Loss: 0.37121692299842
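The first few losses of the first epoch cluster around 2.2 to 2.4, while most losses in the second block of logs are well below 1. The starting value is what one would expect from a freshly initialized 10-way head: if its predictions are close to uniform, the cross-entropy is about ln(10) ≈ 2.30. A small NumPy check, assuming the classification loss is standard single-label cross-entropy (the logits below are dummy values and `cross_entropy` is an illustrative helper):

```python
import math
import numpy as np

def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean cross-entropy from raw logits of shape (batch, num_classes)."""
    # subtract the row max for numerical stability before the log-softmax
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())

uniform_logits = np.zeros((2, 10))  # identical logits -> uniform probabilities
dummy_labels = np.array([3, 8])
print(round(cross_entropy(uniform_logits, dummy_labels), 4), round(math.log(10), 4))  # 2.3026 2.3026
```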

## Evaluate the model

Finally, we evaluate the fine-tuned model on the validation set (the 10% of the training slice held out above), using the Datasets library to compute the accuracy. To score the held-out test set instead, iterate over `test_dataloader`.

In [None]:
from tqdm.notebook import tqdm
from datasets import load_metric  # removed in recent datasets releases; evaluate.load("accuracy") replaces it

accuracy = load_metric("accuracy")

model.eval()
with torch.no_grad():  # gradients are not needed for evaluation
    for batch in tqdm(val_dataloader):
        # get the inputs and move them to the device
        inputs = batch["pixel_values"].to(device)

        # forward pass (no labels needed, since only the logits are used)
        outputs = model(inputs=inputs)
        predictions = outputs.logits.argmax(-1).cpu().numpy()
        references = batch["label"].numpy()
        accuracy.add_batch(predictions=predictions, references=references)

final_score = accuracy.compute()
print("Accuracy on validation set:", final_score)

Accuracy on validation set: {'accuracy': 0.78}
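The evaluation loop iterates over `val_dataloader`, i.e. 25 batches of 2 images. `add_batch` accumulates predictions across batches and `compute` scores them all at once, so the result is the overall fraction of correct predictions: 0.78 on 50 images corresponds to 39 correct, and each image moves the score by 2 percentage points. A sketch of that accumulation pattern with a hypothetical `RunningAccuracy` class (a stand-in, not the Datasets metric implementation):

```python
class RunningAccuracy:
    """Hypothetical minimal accumulator mirroring the add_batch/compute pattern."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def add_batch(self, predictions, references):
        self.correct += sum(int(p == r) for p, r in zip(predictions, references))
        self.total += len(references)

    def compute(self):
        return {"accuracy": self.correct / self.total}

metric = RunningAccuracy()
metric.add_batch(predictions=[3, 8], references=[3, 1])        # 1 of 2 correct
metric.add_batch(predictions=[0, 5, 5], references=[0, 5, 5])  # 3 of 3 correct
print(metric.compute())  # {'accuracy': 0.8}
```

Here the pooled score (4/5 = 0.8) differs from the mean of the per-batch accuracies ((0.5 + 1.0) / 2 = 0.75) because the batches have different sizes; with the equal-sized batches of the notebook, the two would coincide.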
