In [1]:
from datasets import load_dataset

# load cifar10 (only small portion for demonstration purposes) 
train_ds, test_ds = load_dataset('cifar10', split=['train[:500]', 'test[:200]'])
# split up training into training + validation
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']

Found cached dataset cifar10 (/home/amy/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


  0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
id2label = {idx:label for idx,label in enumerate(train_ds.features['label'].names)}
label2id = {label:idx for idx, label in id2label.items()}
print(id2label)
print(label2id)

{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}


In [3]:
from transformers import PerceiverFeatureExtractor

feature_extractor = PerceiverFeatureExtractor()

2023-08-01 20:53:47.334928: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-08-01 20:53:47.463795: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
import numpy as np

def preprocess_images(examples):
    examples['pixel_values'] = feature_extractor(examples['img'], return_tensors="pt").pixel_values
    return examples

In [5]:
# Set the transforms
train_ds.set_transform(preprocess_images)
val_ds.set_transform(preprocess_images)
test_ds.set_transform(preprocess_images)

In [6]:
train_ds[:2]

{'img': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x7F4480F49900>,
  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x7F4480F494B0>],
 'label': [4, 9],
 'pixel_values': tensor([[[[ 0.0569,  0.0569,  0.0569,  ..., -0.1314, -0.1314, -0.1314],
           [ 0.0569,  0.0569,  0.0569,  ..., -0.1314, -0.1314, -0.1314],
           [ 0.0569,  0.0569,  0.0569,  ..., -0.1486, -0.1486, -0.1486],
           ...,
           [ 1.2557,  1.2557,  1.2557,  ...,  0.0912,  0.0912,  0.0912],
           [ 1.2557,  1.2557,  1.2557,  ...,  0.0912,  0.0912,  0.0912],
           [ 1.2557,  1.2557,  1.2557,  ...,  0.0912,  0.0912,  0.0912]],
 
          [[ 0.3277,  0.3277,  0.3277,  ...,  0.1877,  0.1877,  0.1877],
           [ 0.3277,  0.3277,  0.3277,  ...,  0.1702,  0.1702,  0.1702],
           [ 0.3277,  0.3277,  0.3277,  ...,  0.1702,  0.1702,  0.1702],
           ...,
           [ 1.4482,  1.4482,  1.4307,  ...,  0.3102,  0.3102,  0.3102],
           [ 1.4482,  1.4482,  

In [7]:
from torch.utils.data import DataLoader
import torch

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

train_batch_size = 2
eval_batch_size = 2

train_dataloader = DataLoader(train_ds, shuffle=True, collate_fn=collate_fn, batch_size=train_batch_size)
val_dataloader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=eval_batch_size)
test_dataloader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=eval_batch_size)

In [8]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k, v.shape)

pixel_values torch.Size([2, 3, 224, 224])
labels torch.Size([2])


In [9]:
assert batch['pixel_values'].shape == (train_batch_size, 3, 224, 224)
assert batch['labels'].shape == (train_batch_size,)

In [10]:
next(iter(val_dataloader))['pixel_values'].shape

torch.Size([2, 3, 224, 224])

In [11]:
from transformers import PerceiverForImageClassificationLearned 

device = torch.device("cpu")

model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned",
                                                               num_labels=10,
                                                               id2label=id2label,
                                                               label2id=label2id,
                                                               ignore_mismatched_sizes=True)
model.to(device)

Some weights of PerceiverForImageClassificationLearned were not initialized from the model checkpoint at deepmind/vision-perceiver-learned and are newly initialized because the shapes did not match:
- perceiver.decoder.decoder.final_layer.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([10, 1024]) in the model instantiated
- perceiver.decoder.decoder.final_layer.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


PerceiverForImageClassificationLearned(
  (perceiver): PerceiverModel(
    (input_preprocessor): PerceiverImagePreprocessor(
      (convnet_1x1): Conv2d(3, 256, kernel_size=(1, 1), stride=(1, 1))
      (position_embeddings): PerceiverTrainablePositionEncoding()
      (positions_projection): Linear(in_features=256, out_features=256, bias=True)
      (conv_after_patches): Identity()
    )
    (embeddings): PerceiverEmbeddings()
    (encoder): PerceiverEncoder(
      (cross_attention): PerceiverLayer(
        (attention): PerceiverAttention(
          (self): PerceiverSelfAttention(
            (layernorm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (query): Linear(in_features=1024, out_features=512, bias=True)
            (key): Linear(in_features=512, out_features=512, bias=True)
            (value): Linear(in_features=512, out_features=512, bias=True)
            (dropout): Dropout(

In [12]:
from transformers import AdamW
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score


optimizer = AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(2):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    for batch in tqdm(train_dataloader):
         # get the inputs; 
         inputs = batch["pixel_values"].to(device)
         labels = batch["labels"].to(device)

         # zero the parameter gradients
         optimizer.zero_grad()

         # forward + backward + optimize
         outputs = model(inputs=inputs, labels=labels)
         loss = outputs.loss
         loss.backward()
         optimizer.step()

         # evaluate
         predictions = outputs.logits.argmax(-1).cpu().detach().numpy()
         accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
         print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

torch.save(model.state_dict(), 'perceiver.pth')

Epoch: 0




  0%|          | 0/225 [00:00<?, ?it/s]

Loss: 2.264371395111084, Accuracy: 0.0
Loss: 2.222069025039673, Accuracy: 0.5
Loss: 2.3770434856414795, Accuracy: 0.0
Loss: 2.4273457527160645, Accuracy: 0.0
Loss: 2.0911569595336914, Accuracy: 0.5
Loss: 2.266533851623535, Accuracy: 0.0
Loss: 2.263072967529297, Accuracy: 0.0
Loss: 2.637988328933716, Accuracy: 0.0
Loss: 2.1883111000061035, Accuracy: 0.0
Loss: 2.240454912185669, Accuracy: 0.5
Loss: 2.4389538764953613, Accuracy: 0.0
Loss: 2.6921257972717285, Accuracy: 0.0
Loss: 2.6161231994628906, Accuracy: 0.0
Loss: 2.940539598464966, Accuracy: 0.0
Loss: 2.5334105491638184, Accuracy: 0.0
Loss: 2.2088370323181152, Accuracy: 0.0
Loss: 1.9955791234970093, Accuracy: 0.0
Loss: 2.206495523452759, Accuracy: 0.0
Loss: 2.9289488792419434, Accuracy: 0.0
Loss: 2.5614700317382812, Accuracy: 0.0
Loss: 2.0577824115753174, Accuracy: 0.5
Loss: 2.6026034355163574, Accuracy: 0.0
Loss: 2.065002202987671, Accuracy: 0.0
Loss: 2.71480655670166, Accuracy: 0.0
Loss: 2.744630813598633, Accuracy: 0.0
Loss: 2.0259

  0%|          | 0/225 [00:00<?, ?it/s]

Loss: 2.089705228805542, Accuracy: 0.5
Loss: 0.007632815279066563, Accuracy: 1.0
Loss: 0.07831308245658875, Accuracy: 1.0
Loss: 0.6276270747184753, Accuracy: 0.5
Loss: 2.671706199645996, Accuracy: 0.5
Loss: 0.5211066603660583, Accuracy: 0.5
Loss: 1.8519608974456787, Accuracy: 0.0
Loss: 0.42271697521209717, Accuracy: 1.0
Loss: 1.0217959880828857, Accuracy: 0.5
Loss: 0.15278173983097076, Accuracy: 1.0
Loss: 0.04442048817873001, Accuracy: 1.0
Loss: 0.3616112172603607, Accuracy: 1.0
Loss: 0.45234644412994385, Accuracy: 1.0
Loss: 1.0401544570922852, Accuracy: 0.5
Loss: 1.8183624744415283, Accuracy: 0.0
Loss: 0.5927532315254211, Accuracy: 1.0
Loss: 0.11321651935577393, Accuracy: 1.0
Loss: 0.054742567241191864, Accuracy: 1.0
Loss: 0.11891348659992218, Accuracy: 1.0
Loss: 0.031357161700725555, Accuracy: 1.0
Loss: 0.358981192111969, Accuracy: 1.0
Loss: 0.1740778088569641, Accuracy: 1.0
Loss: 0.5239387154579163, Accuracy: 1.0
Loss: 0.6655173897743225, Accuracy: 0.5
Loss: 1.472339391708374, Accur

In [15]:
from tqdm.notebook import tqdm
from datasets import load_metric

accuracy = load_metric("accuracy")

  accuracy = load_metric("accuracy")


In [16]:
#model = torch.load(PATH)
from sklearn.metrics import accuracy_score
model.eval()

with torch.no_grad():
      for batch in tqdm(val_dataloader):
            # get the inputs; 
            inputs = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            # forward pass
            outputs = model(inputs=inputs, labels=labels)
            logits = outputs.logits 
            predictions = logits.argmax(-1).cpu().detach().numpy()
            references = batch["labels"].numpy()
            #accuracy = accuracy_score(y_true=batch["labels"].numpy(), y_pred=predictions)
            accuracy.add_batch(predictions=predictions, references=references)

final_score = accuracy.compute()
print("Accuracy on test set:", final_score)

  0%|          | 0/25 [00:00<?, ?it/s]

Accuracy on test set: {'accuracy': 0.8}
