<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Perceiver/Fine_tune_Perceiver_for_text_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Set up the environment

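As usual, we first install Hugging Face Transformers and Datasets. Transformers is installed from the main branch on GitHub here; a recent PyPI release that includes the Perceiver models should work as well.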

In [None]:
!pip install -q git+https://github.com/huggingface/transformers.git


In [None]:
!pip install -q datasets


## Prepare data

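Here we take a small portion of the IMDB dataset, a binary text classification dataset ("is a movie review positive or negative?"): the first and last 10 reviews of the training split, and the first and last 5 reviews of the test split. Because both splits are sorted by label (all negative reviews come first), taking examples from both ends ensures that both classes are represented.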

In [None]:
from datasets import load_dataset

train_ds, test_ds = load_dataset("imdb", split=['train[:10]+train[-10:]', 'test[:5]+test[-5:]'])
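The split strings above use the Datasets slicing syntax, where `+` concatenates two slices. A minimal sketch of why slicing from both ends matters, assuming (as in the Hub version of IMDB) that all 12,500 negative training reviews precede the 12,500 positive ones; the label list is a toy stand-in, not the real dataset:

```python
from collections import Counter

# Toy stand-in for the IMDB train labels (assumption: sorted, 0 = neg first, then 1 = pos)
train_labels = [0] * 12_500 + [1] * 12_500

# Equivalent of split="train[:10]" alone: only negative reviews
head_only = train_labels[:10]

# Equivalent of split="train[:10]+train[-10:]": both ends of the split
both_ends = train_labels[:10] + train_labels[-10:]

print(Counter(head_only))  # Counter({0: 10})
print(Counter(both_ends))  # Counter({0: 10, 1: 10})
```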


We create `id2label` and `label2id` mappings between integer class ids and label names, which are handy at inference time.

In [None]:
labels = train_ds.features['label'].names
print(labels)

['neg', 'pos']


In [None]:
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(id2label)

{0: 'neg', 1: 'pos'}
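The two dictionaries are inverses of each other. A quick sanity check, using the `labels` list printed above (hard-coded here so the snippet runs on its own):

```python
labels = ["neg", "pos"]  # copied from train_ds.features['label'].names above

id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}

# Every id survives a round trip through both mappings, and vice versa
assert all(label2id[id2label[i]] == i for i in id2label)
assert all(id2label[label2id[name]] == name for name in label2id)

print(label2id)  # {'neg': 0, 'pos': 1}
```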


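Next, we prepare the data for the model using the tokenizer. The Perceiver tokenizer operates directly on raw UTF-8 bytes rather than on a subword vocabulary, and we pad and truncate every review to 2,048 positions, the maximum sequence length the model accepts.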

In [None]:
from transformers import PerceiverTokenizer

tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")

train_ds = train_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                        batched=True)
test_ds = test_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                      batched=True)
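Because the tokenizer works on raw bytes, the model's input embedding in the printout further down has only 262 rows (256 byte values plus a few special tokens). A minimal, hypothetical re-implementation of that idea, assuming special-token ids 0 to 5 (padding = 0, `[CLS]` = 4, `[SEP]` = 5) and a byte offset of 6; these ids are assumptions, so compare them with `tokenizer.pad_token_id`, `tokenizer.cls_token_id` and `tokenizer.sep_token_id` before relying on them:

```python
def byte_tokenize(text, max_length=2048, pad_id=0, cls_id=4, sep_id=5, offset=6):
    """Hypothetical byte-level tokenizer returning (input_ids, attention_mask).

    Assumed layout: ids below `offset` are special tokens, byte b maps to b + offset.
    """
    byte_ids = [b + offset for b in text.encode("utf-8")]
    # Truncate so that [CLS] + bytes + [SEP] fits in max_length
    byte_ids = byte_ids[: max_length - 2]
    input_ids = [cls_id] + byte_ids + [sep_id]
    attention_mask = [1] * len(input_ids)
    # Pad on the right up to max_length (like padding="max_length")
    num_pad = max_length - len(input_ids)
    input_ids += [pad_id] * num_pad
    attention_mask += [0] * num_pad
    return input_ids, attention_mask


ids, mask = byte_tokenize("hi", max_length=8)
print(ids)   # [4, 110, 111, 5, 0, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0, 0, 0]
```

Under these assumptions the vocabulary has 256 + 6 = 262 entries, matching the `Embedding(262, 768)` layer in the model printout. Note that a non-ASCII character such as "é" takes two positions, since it is two bytes in UTF-8.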


We set the format to PyTorch tensors (keeping only the columns the model needs) and create standard PyTorch DataLoaders.

In [None]:
train_ds.set_format(type="torch", columns=['input_ids', 'attention_mask', 'label'])
test_ds.set_format(type="torch", columns=['input_ids', 'attention_mask', 'label'])

In [None]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=4, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=4)
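The DataLoader's default collate function turns a list of per-example dicts into a single dict of batched tensors, which is why each batch below has shape `[4, 2048]`. A pure-Python sketch of that stacking step, using nested lists instead of tensors; `collate_dicts` and `shape` are hypothetical helpers, not the PyTorch implementation:

```python
def collate_dicts(examples):
    """Stack a list of {key: value} dicts into {key: [values...]}, preserving order."""
    keys = examples[0].keys()
    return {key: [example[key] for example in examples] for key in keys}


def shape(value):
    """Shape of a (possibly nested) list, e.g. [4, 2048]."""
    dims = []
    while isinstance(value, list):
        dims.append(len(value))
        value = value[0] if value else None
    return dims


# Four toy examples with sequence length 2048, mimicking the tokenized dataset
examples = [
    {"input_ids": [0] * 2048, "attention_mask": [1] * 2048, "label": i % 2}
    for i in range(4)
]
batch = collate_dicts(examples)
for k, v in batch.items():
    print(k, shape(v))
# input_ids [4, 2048]
# attention_mask [4, 2048]
# label [4]
```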

Let's verify a batch (it's always important to inspect your data!).

In [None]:
batch = next(iter(train_dataloader))
for k, v in batch.items():
    print(k, v.shape)

attention_mask torch.Size([4, 2048])
input_ids torch.Size([4, 2048])
label torch.Size([4])


In [None]:
tokenizer.decode(batch['input_ids'][3])

"Very smart, sometimes shocking, I just love it. It shoved one more side of David's brilliant talent. He impressed me greatly! David is the best. The movie captivates your attention for every second.<pad><pad><pad><pad><pad><pad><pad><pad>...

In [None]:
batch['label']

tensor([0, 0, 0, 1])
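The long run of `<pad>` tokens at the end of the decoded review is exactly the region the attention mask marks with zeros. A sketch of recovering the text from a padded sequence by keeping only masked-in positions and undoing the byte offset, assuming that special ids occupy 0 to 5 and each byte b is stored as b + 6 (an assumption about the tokenizer's id layout); `decode_bytes` is a hypothetical helper:

```python
def decode_bytes(input_ids, attention_mask, offset=6):
    """Hypothetical inverse of byte-level tokenization.

    Drops padded positions (mask == 0) and special ids (< offset), then
    turns the remaining ids back into UTF-8 bytes.
    """
    kept = [
        token_id - offset
        for token_id, keep in zip(input_ids, attention_mask)
        if keep and token_id >= offset
    ]
    return bytes(kept).decode("utf-8", errors="replace")


# [CLS], 'h', 'i', '!', [SEP], then padding (assumed ids)
ids = [4, 110, 111, 39, 5, 0, 0, 0]
mask = [1, 1, 1, 1, 1, 0, 0, 0]
print(decode_bytes(ids, mask))  # hi!
```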

## Define model

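Next, we define the model and move it to the GPU (if one is available). `PerceiverForSequenceClassification` reuses the pretrained `deepmind/language-perceiver` encoder and adds a freshly initialized classification decoder, so the warning that some checkpoint weights (those of the original masked language modeling decoder) were not used is expected.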

In [None]:
from transformers import PerceiverForSequenceClassification
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PerceiverForSequenceClassification.from_pretrained(
    "deepmind/language-perceiver",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
model.to(device)

Some weights of the model checkpoint at deepmind/language-perceiver were not used when initializing PerceiverForSequenceClassification: ['perceiver.decoder.decoding_cross_attention.attention.self.layernorm1.bias', 'perceiver.decoder.decoding_cross_attention.mlp.dense1.bias', 'perceiver.decoder.decoding_cross_attention.attention.output.dense.bias', 'perceiver.decoder.decoding_cross_attention.attention.self.key.weight', 'perceiver.decoder.decoding_cross_attention.attention.self.layernorm1.weight', 'perceiver.decoder.decoding_cross_attention.attention.self.query.bias', 'perceiver.decoder.decoding_cross_attention.attention.self.key.bias', 'perceiver.decoder.decoding_cross_attention.attention.self.layernorm2.bias', 'embedding_decoder.bias', 'perceiver.decoder.output_position_encodings.position_embeddings', 'perceiver.decoder.decoding_cross_attention.mlp.dense2.bias', 'perceiver.decoder.decoding_cross_attention.mlp.dense1.weight', 'perceiver.decoder.decoding_cross_attention.attention.output.

PerceiverForSequenceClassification(
  (perceiver): PerceiverModel(
    (input_preprocessor): PerceiverTextPreprocessor(
      (embeddings): Embedding(262, 768)
      (position_embeddings): Embedding(2048, 768)
    )
    (embeddings): PerceiverEmbeddings()
    (encoder): PerceiverEncoder(
      (cross_attention): PerceiverLayer(
        (attention): PerceiverAttention(
          (self): PerceiverSelfAttention(
            (layernorm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (layernorm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (query): Linear(in_features=1280, out_features=256, bias=True)
            (key): Linear(in_features=768, out_features=256, bias=True)
            (value): Linear(in_features=768, out_features=1280, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): PerceiverSelfOutput(
            (dense): Linear(in_features=1280, out_features=1280, bias=True)
          )
        )
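The printed encoder starts with a cross-attention layer in which an array of latent vectors (width 1280) attends to the 2,048 byte embeddings (width 768): `query` projects the latents, while `key` and `value` project the inputs. Because the number of latents is fixed, the cost of this step grows only linearly with input length. A stripped-down, single-head sketch in pure Python with toy sizes, deliberately leaving out the layer norms, residual connections, multiple heads and output projection of the real layer:

```python
import math
import random


def matmul(a, b):
    """(n, k) x (k, m) -> (n, m) for nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(latents, inputs, w_q, w_k, w_v):
    """Single-head cross-attention: latents (N, d_lat) attend to inputs (M, d_in)."""
    q = matmul(latents, w_q)  # (N, d_qk)
    k = matmul(inputs, w_k)   # (M, d_qk)
    v = matmul(inputs, w_v)   # (M, d_lat)
    scale = 1 / math.sqrt(len(w_q[0]))
    k_t = [list(col) for col in zip(*k)]  # (d_qk, M)
    scores = matmul(q, k_t)               # (N, M): one row of scores per latent
    weights = [softmax([s * scale for s in row]) for row in scores]
    return matmul(weights, v)             # (N, d_lat)


def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
# Toy sizes; the printout suggests d_lat=1280, d_in=768, d_qk=256 and M=2048 for the real model
N, M, d_lat, d_in, d_qk = 4, 16, 6, 3, 2
latents = rand_matrix(N, d_lat, rng)
inputs = rand_matrix(M, d_in, rng)
out = cross_attention(
    latents,
    inputs,
    rand_matrix(d_lat, d_qk, rng),  # like `query`: latent width -> qk width
    rand_matrix(d_in, d_qk, rng),   # like `key`:   input width  -> qk width
    rand_matrix(d_in, d_lat, rng),  # like `value`: input width  -> latent width
)
print(len(out), len(out[0]))  # 4 6: one output row per latent, not per input byte
```

As a rough cost comparison, if the model uses 256 latents (an assumption; check `model.config.num_latents`), the cross-attention score matrix has 256 × 2048 = 524,288 entries, versus 2048 × 2048 = 4,194,304 for full self-attention over the bytes, i.e. 8 times fewer.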


## Train the model

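Here we train the model with a plain PyTorch training loop. With only 20 training examples (5 batches of 4 per epoch), the per-batch loss and accuracy printed below are noisy, so individual values should not be over-interpreted.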

In [None]:
import torch
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# transformers.AdamW is deprecated (removed in recent releases), so we use PyTorch's AdamW
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(20):  # loop over the (tiny) training set multiple times
    print("Epoch:", epoch)
    for batch in tqdm(train_dataloader):
        # get the inputs and move them to the device
        inputs = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs=inputs, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # report loss and accuracy on the current batch
        predictions = outputs.logits.argmax(-1).detach().cpu().numpy()
        accuracy = accuracy_score(y_true=batch["label"].numpy(), y_pred=predictions)
        print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

Epoch: 0


Loss: 0.8304584622383118, Accuracy: 0.25
Loss: 0.9324597120285034, Accuracy: 0.75
Loss: 2.6381170749664307, Accuracy: 0.0
Loss: 0.7451722025871277, Accuracy: 0.5
Loss: 0.7533012628555298, Accuracy: 0.5
Epoch: 1


Loss: 1.2992808818817139, Accuracy: 0.25
Loss: 1.0224494934082031, Accuracy: 0.25
Loss: 0.7040843963623047, Accuracy: 0.25
Loss: 0.6907804012298584, Accuracy: 0.5
Loss: 0.7203900814056396, Accuracy: 0.25
Epoch: 2


Loss: 0.5046183466911316, Accuracy: 1.0
Loss: 0.8854724168777466, Accuracy: 0.5
Loss: 0.938153862953186, Accuracy: 0.5
Loss: 0.9919759035110474, Accuracy: 0.25
Loss: 0.686953604221344, Accuracy: 0.5
Epoch: 3


Loss: 0.5484501719474792, Accuracy: 0.75
Loss: 1.2984037399291992, Accuracy: 0.25
Loss: 1.0918217897415161, Accuracy: 0.25
Loss: 0.551606297492981, Accuracy: 0.75
Loss: 0.6901388764381409, Accuracy: 0.5
Epoch: 4


Loss: 0.6505652070045471, Accuracy: 0.75
Loss: 0.763052225112915, Accuracy: 0.25
Loss: 0.5942457914352417, Accuracy: 0.75
Loss: 0.7520264387130737, Accuracy: 0.25
Loss: 0.6076421737670898, Accuracy: 0.75
Epoch: 5


Loss: 0.6453779935836792, Accuracy: 0.75
Loss: 0.6081874370574951, Accuracy: 0.75
Loss: 0.6792967319488525, Accuracy: 0.5
Loss: 0.2903488576412201, Accuracy: 1.0
Loss: 1.3074119091033936, Accuracy: 0.5
Epoch: 6


Loss: 0.5321564078330994, Accuracy: 0.75
Loss: 0.42676296830177307, Accuracy: 1.0
Loss: 0.9039682745933533, Accuracy: 0.5
Loss: 0.3487512767314911, Accuracy: 0.75
Loss: 0.9425325989723206, Accuracy: 0.25
Epoch: 7


Loss: 0.7495611906051636, Accuracy: 0.5
Loss: 0.7849998474121094, Accuracy: 0.5
Loss: 0.5380256175994873, Accuracy: 0.75
Loss: 1.2848037481307983, Accuracy: 0.25
Loss: 0.5453243255615234, Accuracy: 0.75
Epoch: 8


Loss: 0.6200200915336609, Accuracy: 0.5
Loss: 0.8176119327545166, Accuracy: 0.5
Loss: 0.7581350207328796, Accuracy: 0.5
Loss: 0.6954535841941833, Accuracy: 0.5
Loss: 0.7155317068099976, Accuracy: 0.25
Epoch: 9


Loss: 0.5425307750701904, Accuracy: 0.75
Loss: 0.4797278344631195, Accuracy: 1.0
Loss: 0.3875162601470947, Accuracy: 1.0
Loss: 0.4902285039424896, Accuracy: 0.75
Loss: 0.7444762587547302, Accuracy: 0.25
Epoch: 10


Loss: 0.9322880506515503, Accuracy: 0.5
Loss: 0.6616388559341431, Accuracy: 0.75
Loss: 0.7554448246955872, Accuracy: 0.5
Loss: 0.40487349033355713, Accuracy: 1.0
Loss: 0.4681805968284607, Accuracy: 0.75
Epoch: 11


Loss: 0.4545046091079712, Accuracy: 0.75
Loss: 0.7147601246833801, Accuracy: 0.5
Loss: 0.5335186719894409, Accuracy: 0.5
Loss: 0.26640114188194275, Accuracy: 1.0
Loss: 0.7020363211631775, Accuracy: 0.5
Epoch: 12


Loss: 0.9181938171386719, Accuracy: 0.5
Loss: 0.24263136088848114, Accuracy: 1.0
Loss: 0.5084943771362305, Accuracy: 0.5
Loss: 0.2987772822380066, Accuracy: 0.75
Loss: 0.5919318199157715, Accuracy: 0.75
Epoch: 13


Loss: 0.5856934785842896, Accuracy: 0.5
Loss: 0.3123701810836792, Accuracy: 1.0
Loss: 0.5210100412368774, Accuracy: 0.75
Loss: 0.24680882692337036, Accuracy: 1.0
Loss: 0.4321930706501007, Accuracy: 1.0
Epoch: 14


Loss: 0.02089816704392433, Accuracy: 1.0
Loss: 0.5982148051261902, Accuracy: 0.75
Loss: 0.3422252833843231, Accuracy: 1.0
Loss: 0.46750909090042114, Accuracy: 0.75
Loss: 0.15922528505325317, Accuracy: 1.0
Epoch: 15


Loss: 0.1508100926876068, Accuracy: 1.0
Loss: 0.6225242614746094, Accuracy: 0.75
Loss: 0.05214595049619675, Accuracy: 1.0
Loss: 0.024092189967632294, Accuracy: 1.0
Loss: 0.051527444273233414, Accuracy: 1.0
Epoch: 16


Loss: 0.04499734565615654, Accuracy: 1.0
Loss: 0.015084860846400261, Accuracy: 1.0
Loss: 0.42504045367240906, Accuracy: 0.75
Loss: 0.4822450280189514, Accuracy: 0.75
Loss: 5.076355934143066, Accuracy: 0.25
Epoch: 17


Loss: 2.05873966217041, Accuracy: 0.25
Loss: 0.08333556354045868, Accuracy: 1.0
Loss: 1.191737174987793, Accuracy: 0.5
Loss: 1.0786267518997192, Accuracy: 0.5
Loss: 0.743228554725647, Accuracy: 0.5
Epoch: 18


Loss: 0.5613749027252197, Accuracy: 0.75
Loss: 0.3839970827102661, Accuracy: 0.75
Loss: 1.4717833995819092, Accuracy: 0.25
Loss: 0.9521416425704956, Accuracy: 0.25
Loss: 0.40143632888793945, Accuracy: 0.75
Epoch: 19


Loss: 0.640500009059906, Accuracy: 0.75
Loss: 0.34529292583465576, Accuracy: 1.0
Loss: 0.16599130630493164, Accuracy: 1.0
Loss: 0.49584537744522095, Accuracy: 0.75
Loss: 0.42844119668006897, Accuracy: 0.75
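Because each printed value comes from a single batch of 4 reviews, the log jumps around a lot (for example, a loss of 5.08 at the end of epoch 16). One common way to make such a log easier to read is to report per-epoch averages instead. A minimal sketch of a running-mean helper, `RunningMean` (a hypothetical name, not part of any library used here), fed with the epoch 0 losses copied from the log above:

```python
class RunningMean:
    """Accumulates a weighted mean, e.g. of per-batch loss weighted by batch size."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.total += value * n
        self.count += n

    @property
    def mean(self):
        return self.total / self.count if self.count else float("nan")


# (loss, batch size) pairs from epoch 0 of the training log above
epoch_0 = [(0.8304584622383118, 4), (0.9324597120285034, 4), (2.6381170749664307, 4),
           (0.7451722025871277, 4), (0.7533012628555298, 4)]

loss_meter = RunningMean()
for loss, n in epoch_0:
    loss_meter.update(loss, n)
print(f"Epoch 0 mean loss: {loss_meter.mean:.4f}")  # Epoch 0 mean loss: 1.1799
```

Inside the training loop, one would call `loss_meter.update(loss.item(), inputs.size(0))` for every batch and print `loss_meter.mean` once per epoch, creating a fresh meter at the start of each epoch.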


## Evaluate the model

Finally, we evaluate the model on the test set. We compute the accuracy with the Hugging Face Evaluate library, which now provides the metrics that older versions of Datasets exposed through `load_metric` (recent Datasets releases no longer include that function).

In [None]:
!pip install -q evaluate

import torch
import evaluate
from tqdm.notebook import tqdm

accuracy = evaluate.load("accuracy")

model.eval()
with torch.no_grad():  # gradients are not needed for evaluation
    for batch in tqdm(test_dataloader):
        # get the inputs and move them to the device
        inputs = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # forward pass, then accumulate this batch's predictions
        outputs = model(inputs=inputs, attention_mask=attention_mask)
        predictions = outputs.logits.argmax(-1).cpu().numpy()
        references = batch["label"].numpy()
        accuracy.add_batch(predictions=predictions, references=references)

final_score = accuracy.compute()
print("Accuracy on test set:", final_score)
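The accuracy metric accumulates predictions batch by batch and, at the end, computes the fraction that match the references. A pure-Python sketch of that behavior with a hypothetical `StreamingAccuracy` class whose method names mirror `add_batch` and `compute` (it is not the library's implementation); the predictions below are toy values, not results of this model:

```python
class StreamingAccuracy:
    """Hypothetical stand-in for an accuracy metric with add_batch/compute."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def add_batch(self, predictions, references):
        if len(predictions) != len(references):
            raise ValueError("predictions and references must have the same length")
        self.correct += sum(int(p == r) for p, r in zip(predictions, references))
        self.total += len(references)

    def compute(self):
        return {"accuracy": self.correct / self.total}


metric = StreamingAccuracy()
# Toy batches of sizes 4, 4 and 2, like a 10-example test set with batch_size=4
metric.add_batch(predictions=[0, 0, 1, 0], references=[0, 0, 0, 0])
metric.add_batch(predictions=[0, 1, 1, 1], references=[0, 1, 1, 1])
metric.add_batch(predictions=[1, 1], references=[1, 1])
print(metric.compute())  # {'accuracy': 0.9}
```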

## Inference

In [None]:
text = "I loved this movie, it's super good."

input_ids = tokenizer(text, return_tensors="pt").input_ids

# forward pass (a single unpadded sequence, so no attention mask is needed)
with torch.no_grad():
    outputs = model(inputs=input_ids.to(device))
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()

print("Predicted:", model.config.id2label[predicted_class_idx])
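`argmax` only returns the winning class. To also see how confident the model is, one can apply a softmax to the logits. A minimal sketch in pure Python; the logit values are made up for illustration, and in the notebook itself `torch.softmax(logits, dim=-1)` does the same job:

```python
import math

id2label = {0: "neg", 1: "pos"}  # same mapping as created earlier


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [-1.2, 2.3]  # hypothetical logits for one review
probs = softmax(logits)
predicted = max(range(len(probs)), key=probs.__getitem__)
print(id2label[predicted], round(probs[predicted], 3))  # pos 0.971
```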