In [4]:
!pip install transformers




In [5]:
!pip install datasets


Collecting datasets
  Downloading datasets-2.2.2-py3-none-any.whl (346 kB)
     -------------------------------------- 346.8/346.8 kB 4.3 MB/s eta 0:00:00
Collecting pyarrow>=6.0.0
  Downloading pyarrow-8.0.0-cp37-cp37m-win_amd64.whl (17.8 MB)
     --------------------------------------- 17.8/17.8 MB 26.2 MB/s eta 0:00:00
Collecting fsspec[http]>=2021.05.0
  Downloading fsspec-2022.5.0-py3-none-any.whl (140 kB)
     ---------------------------------------- 140.6/140.6 kB ? eta 0:00:00
Collecting pandas
  Downloading pandas-1.1.5-cp37-cp37m-win_amd64.whl (8.7 MB)
     ---------------------------------------- 8.7/8.7 MB 28.0 MB/s eta 0:00:00
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp37-cp37m-win_amd64.whl (551 kB)
     ------------------------------------- 551.8/551.8 kB 11.5 MB/s eta 0:00:00
Collecting multiprocess
  Downloading multiprocess-0.70.13-py37-none-any.whl (115 kB)
     -------------------------------------- 115.1/115.1 kB 7.0 MB/s eta 0:00:00
Collecting dill<0.3.5
  D

In [1]:
import numpy as np
import torch.nn as nn

from datasets import load_dataset, load_metric, Features, ClassLabel, Array3D, Image
from transformers import ViTFeatureExtractor, ViTModel, ViTForImageClassification, TrainingArguments, Trainer, \
    default_data_collator, EarlyStoppingCallback, set_seed
from transformers.modeling_outputs import SequenceClassifierOutput


# Small CIFAR-10 subsets keep fine-tuning fast: 5,000 training / 2,000 test images.
train_ds, test_ds = load_dataset('cifar10', split=['train[:5000]', 'test[:2000]'])

# Hold out 10% of the training subset for validation. The fixed seed makes the
# split -- and therefore the reported validation accuracy -- reproducible across runs.
splits = train_ds.train_test_split(test_size=0.1, seed=42)
train_ds = splits['train']
val_ds = splits['test']

# The feature extractor is loaded from the same checkpoint as the ViT backbone.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
data_collator = default_data_collator


def preprocess_images(examples, extractor=None):
    """Add ViT-ready ``pixel_values`` to a batch of examples.

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` is a dict of
    column lists. Each entry of ``examples['img']`` is converted to a ``uint8``
    array, its last axis is moved to the front (channels-last -> channels-first),
    and the resulting list is passed to the feature extractor.

    Args:
        examples: batch dict with an ``'img'`` column; modified in place.
        extractor: callable with the feature-extractor interface
            (``extractor(images=...)['pixel_values']``). Defaults to the
            module-level ``feature_extractor`` so existing ``map`` calls work unchanged.

    Returns:
        The same ``examples`` dict with a ``'pixel_values'`` entry added.
    """
    if extractor is None:
        extractor = feature_extractor
    # Assumes each image is 3-D and channels-last (true for CIFAR-10 RGB) -- TODO confirm for other datasets.
    images = [
        np.moveaxis(np.array(image, dtype=np.uint8), source=-1, destination=0)
        for image in examples['img']
    ]
    examples['pixel_values'] = extractor(images=images)['pixel_values']
    return examples


# Output schema for the preprocessed splits. `label` reuses the loaded dataset's own
# ClassLabel so the class names cannot drift from CIFAR-10's; `pixel_values` holds the
# 3x224x224 float32 array produced for each image by the feature extractor.
features = Features({
    'label': train_ds.features['label'],
    'img': Image(decode=True, id=None),
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
})

# Apply identical preprocessing to the train, validation and test splits (in that order).
preprocessed_train_ds, preprocessed_val_ds, preprocessed_test_ds = (
    split.map(preprocess_images, batched=True, features=features)
    for split in (train_ds, val_ds, test_ds)
)


class ViTForImageClassification2(nn.Module):
    """Pretrained ViT backbone plus a linear classification head on the [CLS] token.

    A hand-rolled counterpart of ``transformers.ViTForImageClassification`` that works
    with ``Trainer``: ``forward`` returns a ``SequenceClassifierOutput`` whose ``loss``
    is populated when ``labels`` are supplied.

    Args:
        num_labels: number of output classes (10 for CIFAR-10).
        checkpoint: Hugging Face checkpoint name the ViT backbone is loaded from.
    """

    def __init__(self, num_labels=10, checkpoint='google/vit-base-patch16-224-in21k'):
        super().__init__()
        self.vit = ViTModel.from_pretrained(checkpoint)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        """Classify a batch of preprocessed images.

        Args:
            pixel_values: feature-extractor output; presumably ``(batch, 3, 224, 224)``
                to match the ``Array3D`` schema used for the datasets -- TODO confirm.
            labels: optional class ids. When given, cross-entropy loss is computed;
                defaults to ``None`` so the model can also be called for pure inference.

        Returns:
            ``SequenceClassifierOutput`` with ``loss`` (``None`` without labels),
            ``logits`` of shape ``(batch, num_labels)``, and ``hidden_states`` /
            ``attentions`` forwarded unchanged from the backbone output.
        """
        outputs = self.vit(pixel_values=pixel_values)
        # Classify from the embedding at sequence position 0 (the [CLS] token).
        logits = self.classifier(outputs.last_hidden_state[:, 0])

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Training configuration: evaluate on the validation split once per epoch.
args = TrainingArguments(
    "test-cifar-10",  # output_dir for checkpoints
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=False,
    # NOTE(review): not used to select a checkpoint while load_best_model_at_end=False.
    metric_for_best_model="accuracy",
    logging_dir='logs',
    seed=42,
)

# Seed before building the model so the randomly initialised nn.Linear classifier
# head is identical across runs.
set_seed(args.seed)
model = ViTForImageClassification2()


def compute_metrics(eval_pred):
    """Return classification accuracy for ``Trainer`` evaluation.

    Args:
        eval_pred: ``(predictions, labels)`` pair (e.g. an ``EvalPrediction``).
            ``predictions`` are per-class scores of shape ``(n, num_labels)``, or a
            tuple whose first element is such an array; ``labels`` are ``n`` class ids.

    Returns:
        ``{"accuracy": float}``: the fraction of argmax predictions equal to the labels.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        # Models that also return extra outputs yield a tuple; the logits come first.
        predictions = predictions[0]
    predicted_classes = np.argmax(predictions, axis=1)
    # Computed directly instead of via datasets.load_metric, which re-loaded the
    # metric script on every evaluation and is unavailable in newer datasets releases.
    return {"accuracy": float(np.mean(predicted_classes == np.asarray(labels)))}


# Fine-tune on the training split, evaluating on the validation split every epoch.
# Test-set inference is done once, in the dedicated prediction cell below.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=preprocessed_train_ds,
    eval_dataset=preprocessed_val_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

  from .autonotebook import tqdm as notebook_tqdm
Reusing dataset cifar10 (C:\Users\Justi\.cache\huggingface\datasets\cifar10\plain_text\1.0.0\447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 77.80it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:15<00:00,  3.04s/ba]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.69s/ba]
Loading cached processed dataset at C:\Users\Justi\.cache\huggingface\datasets\cifar10\plain_text\1.0.0\447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4\cache-741fe2f2646f835f.arrow
The following columns in the training set don't have a corresponding argument in `ViTForImageClassification2.forward` and have been ignored: img. If img are not expected by `ViTForImageClassification2.forward`,  you can safely ignore

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.525954,0.952
2,1.083200,0.29368,0.948
3,0.253400,0.272527,0.948


The following columns in the evaluation set don't have a corresponding argument in `ViTForImageClassification2.forward` and have been ignored: img. If img are not expected by `ViTForImageClassification2.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 4
Downloading builder script: 4.21kB [00:00, 524kB/s]                                                                    
Saving model checkpoint to test-cifar-10\checkpoint-500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
The following columns in the evaluation set don't have a corresponding argument in `ViTForImageClassification2.forward` and have been ignored: img. If img are not expected by `ViTForImageClassification2.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 4
Saving model checkpoint to test-cifar-10\checkpoint-1000
Trainer.model is not a `PreTrainedModel`, only saving its s

In [2]:
# Inference on the held-out test split: per-class logits -> predicted class ids.
outputs = trainer.predict(preprocessed_test_ds)
y_pred = np.argmax(outputs.predictions, axis=1)

The following columns in the test set don't have a corresponding argument in `ViTForImageClassification2.forward` and have been ignored: img. If img are not expected by `ViTForImageClassification2.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 2000
  Batch size = 4


In [5]:
# Reuse the predicted class ids computed in the prediction cell instead of recomputing argmax.
print(y_pred)

[3 8 8 ... 9 8 5]


In [7]:
# Bare last expression: Jupyter renders the dataset's repr (columns and row count).
preprocessed_test_ds

Dataset({
    features: ['label', 'img', 'pixel_values'],
    num_rows: 2000
})


In [12]:
# Raw logits for the first test image, via the named `predictions` field.
print(outputs.predictions[0])

[-0.36240196 -0.3699481  -0.33720705  3.7147849  -0.20393686 -0.3067448
 -0.24341322 -0.43666342 -0.25323468 -0.37184802]
