
## Set-up environment

As usual, we first install Hugging Face Transformers (from source) and Datasets.

In [1]:
!pip install -q git+https://github.com/huggingface/transformers.git



In [2]:
!pip install -q datasets


## Prepare data

Here we take a small portion of the IMDB dataset, a binary text classification dataset ("is a movie review positive or negative?").

In [3]:
from datasets import load_dataset

train_ds, test_ds = load_dataset("imdb", split=['train[:10]+train[-10:]', 'test[:5]+test[-5:]'])
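The split strings above use the Datasets slicing syntax: `train[:10]+train[-10:]` concatenates the first ten and the last ten training rows. The sketch below mimics that selection on a plain list of row indices. `select_rows` is a hypothetical helper written for this illustration, not part of the Datasets API, and the comment about the split being ordered by label is an assumption that the code does not verify.

```python
def select_rows(num_rows, head, tail):
    """Return the row indices picked by a 'split[:head]+split[-tail:]' style slice."""
    indices = list(range(num_rows))
    return indices[:head] + indices[-tail:]

# The IMDB train and test splits each contain 25,000 reviews.
train_rows = select_rows(25000, head=10, tail=10)
test_rows = select_rows(25000, head=5, tail=5)

print(len(train_rows), len(test_rows))  # 20 10
print(train_rows[:3], train_rows[-3:])  # [0, 1, 2] [24997, 24998, 24999]

# Assumption: the split is ordered by label, so taking rows from both ends
# is what gives us examples of both classes in such a tiny subset.
```

Taking rows from both ends rather than a single `train[:20]` slice is presumably a cheap way to get both labels into a very small subset.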




We create id2label and label2id mappings, which are handy at inference time.

In [4]:
labels = train_ds.features['label'].names
print(labels)

['neg', 'pos']


In [5]:
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(id2label)

{0: 'neg', 1: 'pos'}


Next, we prepare the data for the model using the tokenizer.

In [6]:
from transformers import PerceiverTokenizer

tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")

train_ds = train_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                        batched=True)
test_ds = test_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                      batched=True)
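The Perceiver tokenizer works on raw UTF-8 bytes rather than on word pieces, which is why the padded sequences are 2048 positions long and the model's embedding table has only 262 entries (256 byte values plus a handful of special tokens). The following is a rough pure-Python sketch of that idea, not the library's implementation. The special-token ids, the offset of 6, and the truncation rule are assumptions made for illustration, and `byte_encode` is a hypothetical helper.

```python
# Assumed layout: 6 special tokens first, then the 256 possible byte values.
PAD_ID, CLS_ID, SEP_ID = 0, 4, 5  # assumed ids, for illustration only
NUM_SPECIAL_TOKENS = 6
MAX_LENGTH = 2048

def byte_encode(text, max_length=MAX_LENGTH):
    """Encode text as [CLS] + shifted UTF-8 bytes + [SEP], truncated and padded."""
    byte_ids = [b + NUM_SPECIAL_TOKENS for b in text.encode("utf-8")]
    byte_ids = byte_ids[: max_length - 2]  # leave room for [CLS] and [SEP]
    input_ids = [CLS_ID] + byte_ids + [SEP_ID]
    attention_mask = [1] * len(input_ids)  # 1 = real token, 0 = padding
    padding = max_length - len(input_ids)
    return input_ids + [PAD_ID] * padding, attention_mask + [0] * padding

input_ids, attention_mask = byte_encode("Great film!")
print(len(input_ids), sum(attention_mask))  # 2048 13
print(input_ids[:4])                        # [4, 77, 120, 107]
```

Because every byte is a token, a review of a few hundred characters already occupies a few hundred positions, and `padding="max_length"` fills the rest with `[PAD]`.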


We set the format to PyTorch tensors, and create familiar PyTorch dataloaders.

In [7]:
train_ds.set_format(type="torch", columns=['input_ids', 'attention_mask', 'label'])
test_ds.set_format(type="torch", columns=['input_ids', 'attention_mask', 'label'])

In [8]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=4, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=4)
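To see what these settings imply for the loop sizes, here is a small pure-Python sketch of how a dataloader groups example indices into batches. It is only a conceptual stand-in: `make_batches` is a hypothetical helper, the fixed seed is an assumption for reproducibility, and the real `DataLoader` reshuffles at every epoch and also collates tensors.

```python
import random

def make_batches(num_examples, batch_size, shuffle, seed=0):
    """Split example indices into consecutive batches, optionally shuffling first."""
    indices = list(range(num_examples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    return [indices[i:i + batch_size] for i in range(0, num_examples, batch_size)]

train_batches = make_batches(20, batch_size=4, shuffle=True)
test_batches = make_batches(10, batch_size=4, shuffle=False)

print(len(train_batches))              # 5
print([len(b) for b in test_batches])  # [4, 4, 2]
```

With 20 training examples and a batch size of 4, each epoch consists of 5 optimizer steps. The last test batch is smaller because incomplete batches are kept by default.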

Here we check the tensor shapes, a decoded example, and the labels of one batch (it is always important to inspect your data!).

In [9]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

label torch.Size([4])
input_ids torch.Size([4, 2048])
attention_mask torch.Size([4, 2048])


In [10]:
tokenizer.decode(batch['input_ids'][3])

"[CLS]With the mixed reviews this got I wasn't expecting too much, and was pleasantly surprised. It's a very entertaining small crime film with interesting characters, excellent portrayals, writing that's breezy without being glib, and a good pace. It looks good too, in a funky way. Apparently people either like this movie or just hate it, and I'm one who liked it.[SEP][PAD][PAD][PAD]...

In [11]:
batch['label']

tensor([0, 0, 1, 1])

## Define model

Next, we define our model and move it to the GPU if one is available (otherwise it stays on the CPU).

In [12]:
from transformers import PerceiverForSequenceClassification
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PerceiverForSequenceClassification.from_pretrained("deepmind/language-perceiver",
                                                               num_labels=2,
                                                               id2label=id2label,
                                                               label2id=label2id)
model.to(device)


Some weights of PerceiverForSequenceClassification were not initialized from the model checkpoint at deepmind/language-perceiver and are newly initialized: ['perceiver.decoder.decoder.decoding_cross_attention.attention.output.dense.bias', 'perceiver.decoder.decoder.decoding_cross_attention.attention.output.dense.weight', 'perceiver.decoder.decoder.decoding_cross_attention.attention.self.key.bias', 'perceiver.decoder.decoder.decoding_cross_attention.attention.self.key.weight', 'perceiver.decoder.decoder.decoding_cross_attention.attention.self.layernorm1.bias', 'perceiver.decoder.decoder.decoding_cross_attention.attention.self.layernorm1.weight', 'perceiver.decoder.decoder.decoding_cross_attention.attention.self.layernorm2.bias', 'perceiver.decoder.decoder.decoding_cross_attention.attention.self.layernorm2.weight', 'perceiver.decoder.decoder.decoding_cross_attention.attention.self.query.bias', 'perceiver.decoder.decoder.decoding_cross_attention.attention.self.query.weight', 'perceiver.de

PerceiverForSequenceClassification(
  (perceiver): PerceiverModel(
    (input_preprocessor): PerceiverTextPreprocessor(
      (embeddings): Embedding(262, 768)
      (position_embeddings): Embedding(2048, 768)
    )
    (embeddings): PerceiverEmbeddings()
    (encoder): PerceiverEncoder(
      (cross_attention): PerceiverLayer(
        (attention): PerceiverAttention(
          (self): PerceiverSelfAttention(
            (layernorm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (layernorm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (query): Linear(in_features=1280, out_features=256, bias=True)
            (key): Linear(in_features=768, out_features=256, bias=True)
            (value): Linear(in_features=768, out_features=1280, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): PerceiverSelfOutput(
            (dense): Linear(in_features=1280, out_features=1280, bias=True)
          )
        )


## Train the model

Here we train the model using native PyTorch.

In [13]:
from torch.optim import AdamW  # transformers.AdamW is deprecated; use the PyTorch optimizer
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

optimizer = AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(20):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    for batch in tqdm(train_dataloader):
        # move the inputs and labels to the same device as the model
        inputs = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs=inputs, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # accuracy on the current training batch (not a held-out evaluation)
        predictions = outputs.logits.argmax(-1).detach().cpu().numpy()
        accuracy = accuracy_score(y_true=batch["label"].numpy(), y_pred=predictions)
        print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

Epoch: 0




Loss: 0.7277336120605469, Accuracy: 0.0
Loss: 2.500199794769287, Accuracy: 0.25
Loss: 2.0635900497436523, Accuracy: 0.0
Loss: 0.686470091342926, Accuracy: 0.5
Loss: 1.6674597263336182, Accuracy: 0.0
Epoch: 1


Loss: 1.3738338947296143, Accuracy: 0.25
Loss: 0.4841219186782837, Accuracy: 0.75
Loss: 0.7601318359375, Accuracy: 0.5
Loss: 1.8450686931610107, Accuracy: 0.25
Loss: 0.4590146541595459, Accuracy: 1.0
Epoch: 2


Loss: 0.7238979339599609, Accuracy: 0.5
Loss: 0.7228578925132751, Accuracy: 0.25
Loss: 0.6899157762527466, Accuracy: 0.5
Loss: 0.6036877632141113, Accuracy: 0.75
Loss: 0.9303871393203735, Accuracy: 0.0
Epoch: 3


Loss: 0.7703682780265808, Accuracy: 0.0
Loss: 0.6849678158760071, Accuracy: 0.5
Loss: 0.9162462949752808, Accuracy: 0.25
Loss: 0.9860324859619141, Accuracy: 0.25
Loss: 0.7530498504638672, Accuracy: 0.5
Epoch: 4


Loss: 0.5022475719451904, Accuracy: 0.75
Loss: 0.5223850607872009, Accuracy: 1.0
Loss: 0.8100591897964478, Accuracy: 0.0
Loss: 0.5376589298248291, Accuracy: 1.0
Loss: 0.45117220282554626, Accuracy: 0.75
Epoch: 5


Loss: 0.7058809995651245, Accuracy: 0.5
Loss: 0.32386764883995056, Accuracy: 1.0
Loss: 1.1729196310043335, Accuracy: 0.5
Loss: 0.9169785380363464, Accuracy: 0.5
Loss: 0.3909376859664917, Accuracy: 0.75
Epoch: 6


Loss: 0.705514669418335, Accuracy: 0.5
Loss: 0.6714519262313843, Accuracy: 0.75
Loss: 0.41419628262519836, Accuracy: 1.0
Loss: 0.6535038948059082, Accuracy: 0.25
Loss: 0.6299123764038086, Accuracy: 0.5
Epoch: 7


Loss: 0.8169611692428589, Accuracy: 0.5
Loss: 0.29011762142181396, Accuracy: 0.75
Loss: 0.3972108066082001, Accuracy: 1.0
Loss: 0.6474295854568481, Accuracy: 0.5
Loss: 0.8594386577606201, Accuracy: 0.5
Epoch: 8


Loss: 0.8270891904830933, Accuracy: 0.5
Loss: 0.289743572473526, Accuracy: 1.0
Loss: 0.37747830152511597, Accuracy: 0.75
Loss: 0.7093997001647949, Accuracy: 0.5
Loss: 0.3274148404598236, Accuracy: 0.75
Epoch: 9


Loss: 0.6206012964248657, Accuracy: 0.75
Loss: 0.7472628355026245, Accuracy: 0.5
Loss: 0.28839901089668274, Accuracy: 0.75
Loss: 0.44119876623153687, Accuracy: 0.75
Loss: 0.4060913324356079, Accuracy: 0.75
Epoch: 10


Loss: 0.22091594338417053, Accuracy: 0.75
Loss: 0.8029991388320923, Accuracy: 0.5
Loss: 0.7295086979866028, Accuracy: 0.5
Loss: 0.387764573097229, Accuracy: 1.0
Loss: 0.41188323497772217, Accuracy: 1.0
Epoch: 11


Loss: 0.567193329334259, Accuracy: 0.75
Loss: 0.34389716386795044, Accuracy: 0.75
Loss: 0.8517053127288818, Accuracy: 0.5
Loss: 0.2871484160423279, Accuracy: 1.0
Loss: 0.3197101950645447, Accuracy: 0.75
Epoch: 12


Loss: 0.45436349511146545, Accuracy: 0.75
Loss: 0.46574968099594116, Accuracy: 0.75
Loss: 0.26768288016319275, Accuracy: 1.0
Loss: 0.4200645685195923, Accuracy: 0.75
Loss: 0.8014103174209595, Accuracy: 0.5
Epoch: 13


Loss: 0.07230489701032639, Accuracy: 1.0
Loss: 0.5049402713775635, Accuracy: 0.75
Loss: 1.8300375938415527, Accuracy: 0.5
Loss: 0.2848842442035675, Accuracy: 1.0
Loss: 0.8164210319519043, Accuracy: 0.75
Epoch: 14


Loss: 0.5933418869972229, Accuracy: 0.75
Loss: 0.5785689353942871, Accuracy: 0.75
Loss: 0.4727745056152344, Accuracy: 1.0
Loss: 0.598716139793396, Accuracy: 0.75
Loss: 0.6283969879150391, Accuracy: 0.25
Epoch: 15


Loss: 0.3756416141986847, Accuracy: 0.75
Loss: 0.3983192443847656, Accuracy: 0.75
Loss: 0.6982207894325256, Accuracy: 0.5
Loss: 0.9058377742767334, Accuracy: 0.0
Loss: 0.6071950197219849, Accuracy: 0.5
Epoch: 16


Loss: 0.39755430817604065, Accuracy: 1.0
Loss: 0.47605663537979126, Accuracy: 0.75
Loss: 1.2867577075958252, Accuracy: 0.5
Loss: 0.28596776723861694, Accuracy: 1.0
Loss: 1.0020177364349365, Accuracy: 0.5
Epoch: 17


Loss: 0.14169429242610931, Accuracy: 1.0
Loss: 0.4839552342891693, Accuracy: 0.75
Loss: 0.776053786277771, Accuracy: 0.5
Loss: 0.28095486760139465, Accuracy: 1.0
Loss: 0.4102437496185303, Accuracy: 1.0
Epoch: 18


Loss: 0.30409592390060425, Accuracy: 0.75
Loss: 0.5624896883964539, Accuracy: 0.5
Loss: 0.5073601007461548, Accuracy: 0.75
Loss: 0.9168227910995483, Accuracy: 0.25
Loss: 0.39418065547943115, Accuracy: 0.75
Epoch: 19


Loss: 0.4792799651622772, Accuracy: 0.75
Loss: 0.26577720046043396, Accuracy: 1.0
Loss: 0.4208076000213623, Accuracy: 0.75
Loss: 0.7527266144752502, Accuracy: 0.5
Loss: 0.5673931837081909, Accuracy: 0.75
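To help read the log above, the sketch below recomputes the two printed quantities by hand for one batch of four examples. It assumes the model applies a standard softmax cross-entropy to its logits when `labels` are passed. The logits are made-up values, and `cross_entropy` and `batch_accuracy` are hypothetical helpers, not functions from Transformers or scikit-learn.

```python
import math

def cross_entropy(logits, label):
    """Softmax cross-entropy for one example: -log p(label)."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[label]

def batch_accuracy(batch_logits, labels):
    """Fraction of examples whose arg-max logit matches the label."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in batch_logits]
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)

# Hypothetical logits for a batch of 4 reviews (columns: neg, pos).
batch_logits = [[0.2, -0.1], [1.5, 0.3], [-0.4, 0.9], [0.6, 0.1]]
labels = [0, 0, 1, 1]

mean_loss = sum(cross_entropy(row, y) for row, y in zip(batch_logits, labels)) / len(labels)
print("mean loss:", mean_loss)
print("accuracy:", batch_accuracy(batch_logits, labels))            # 0.75
print("chance-level loss:", round(cross_entropy([0.0, 0.0], 0), 4))  # 0.6931
```

Two things follow for the log. With a batch size of 4, the accuracy can only be 0, 0.25, 0.5, 0.75, or 1. And a loss near 0.69 (ln 2) means the model is no better than a coin flip on that batch. Both numbers are measured on training batches, so they say little about generalization.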


## Inference

In [16]:
text = "I loved this movie, it's super good."

input_ids = tokenizer(text, return_tensors="pt").input_ids

# switch off dropout and gradient tracking for inference
model.eval()
with torch.no_grad():
    outputs = model(inputs=input_ids.to(device))
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()

print("Predicted:", model.config.id2label[predicted_class_idx])

Predicted: pos
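The arg-max only tells us which class wins, not by how much. A softmax over the logits turns them into probabilities, as in the pure-Python sketch below. The logit values are hypothetical, not the model's actual output for this sentence, and `softmax` is a helper defined here for illustration.

```python
import math

id2label = {0: "neg", 1: "pos"}

def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    max_logit = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [-1.2, 2.3]  # hypothetical logits for (neg, pos)
probs = softmax(logits)
predicted_idx = max(range(len(probs)), key=probs.__getitem__)

print("Predicted:", id2label[predicted_idx])
print("Probabilities:", {id2label[i]: round(p, 3) for i, p in enumerate(probs)})
```

With the real model, the same result can be obtained by applying `softmax(-1)` to `outputs.logits`.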
