In [23]:
from datasets import load_dataset, load_metric
from datasets import Features, ClassLabel, Array3D
from transformers import ViTFeatureExtractor, default_data_collator
from transformers import ViTForImageClassification
from transformers import ViTModel
from transformers import TrainingArguments, Trainer
import torch.nn as nn
from transformers.modeling_outputs import SequenceClassifierOutput
import numpy as np

In [2]:
dataset = load_dataset("imagefolder", data_dir="./")
train_ds = dataset["train"]
test_ds = dataset["test"]

Resolving data files: 100%|██████████| 738/738 [00:00<00:00, 368411.85it/s]
Resolving data files: 100%|██████████| 83/83 [00:00<00:00, 209336.88it/s]
Using custom data configuration default-0d3c020d50177a97
Found cached dataset imagefolder (/home/johann/.cache/huggingface/datasets/imagefolder/default-0d3c020d50177a97/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)
100%|██████████| 2/2 [00:00<00:00, 302.65it/s]


In [3]:
splits = train_ds.train_test_split(test_size=0.1)

train_ds = splits['train']

val_ds = splits['test']

In [35]:
features = Features({
    'label': ClassLabel(names=['lie', 'run', 'sit', 'walk_stand']),
    'img': Array3D(dtype="int64", shape=(3,32,32)),
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
})

train_ds.features['image']

Image(decode=True, id=None)

In [7]:
test_split = test_ds.train_test_split(test_size=0.01)
test_ds = test_split['train']

In [13]:
metric = load_metric("f1")

  metric = load_metric("f1")


In [14]:
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

In [15]:
def preprocess_images(examples):
    
    images = examples['image']
    images = [np.array(image, dtype=np.uint8) for image in images]
    images = [np.moveaxis(image, source=-1, destination=0) for image in images]
    inputs = feature_extractor(images=images)
    examples['pixel_values'] = inputs['pixel_values']

    return examples

features = Features({
    'label': ClassLabel(names=['lie', 'run', 'sit', 'walk_stand']),
    'img': Array3D(dtype="int64", shape=(3,32,32)),
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
})

preprocessed_train_ds = train_ds.map(preprocess_images, batched=True)
preprocessed_val_ds = val_ds.map(preprocess_images, batched=True)
preprocessed_test_ds = test_ds.map(preprocess_images, batched=True)

100%|██████████| 2/2 [00:10<00:00,  5.02s/ba]
100%|██████████| 1/1 [00:00<00:00,  1.25ba/s]
100%|██████████| 1/1 [00:00<00:00,  1.21ba/s]


In [18]:
data_collator = default_data_collator

In [20]:
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
# The ViT model was pretrained on ImageNet-21k, a dataset consisting of 14 million images and 21k classes, 
# and fine-tuned on ImageNet, a dataset consisting of 1 million images and 1k classes.
# print(model.config)

model.train()

ViTForImageClassification(
  (shared_parameters): ModuleDict()
  (vit): ViTModel(
    (shared_parameters): ModuleDict()
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0): ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(
                in_features=768, out_features=768, bias=True
                (loras): ModuleDict()
              )
              (key): Linear(
                in_features=768, out_features=768, bias=True
                (loras): ModuleDict()
              )
              (value): Linear(
                in_features=768, out_features=768, bias=True
                (loras): ModuleDict()
              )
              (dropout): Dropout(p=0.0, inplace=False)
              (p

In [22]:
class ViTForImageClassification2(nn.Module):

    def __init__(self, num_labels=4):

        super(ViTForImageClassification2, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels):

        outputs = self.vit(pixel_values=pixel_values)
        logits = self.classifier(outputs)
        loss = None

        if labels is not None:

          loss_fct = nn.CrossEntropyLoss()
          loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [25]:
def compute_metrics(eval_pred):

    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

args = TrainingArguments(
    'test-vit',
    evaluation_strategy = "steps",
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric,
    logging_dir='logs',
)

trainer = Trainer(
    model,
    args,
    train_dataset = preprocessed_train_ds,
    eval_dataset = preprocessed_val_ds,
    data_collator = data_collator,
    compute_metrics = compute_metrics,
)

using `logging_steps` to initialize `eval_steps` to 500
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [27]:
# trainer.train()