In [1]:
from transformers import ViTImageProcessor

model_name = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image

instances_file = "./instances.csv"
dataset_dir = "./dataset"

df = pd.read_csv(instances_file)

class COCODataset(Dataset):
    def __init__(self, instances, img_dir, processor):
        self.instances = instances
        self.img_dir = img_dir
        self.processor = processor

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.instances.iloc[idx, 0])
        img = Image.open(img_path).convert("RGB")
        label = int(self.instances.iloc[idx, 1])

        inputs = self.processor(img, return_tensors='pt')
        return {
            "pixel_values": inputs["pixel_values"].squeeze(),
            "label": torch.tensor(label, dtype=torch.float)
        }

In [3]:
dataset = COCODataset(df, dataset_dir, processor)
print(len(dataset))
print(dataset[0])

25000
{'pixel_values': tensor([[[-0.5137, -0.4510, -0.5294,  ...,  0.3412,  0.3412,  0.3255],
         [-0.4588, -0.4588, -0.4745,  ...,  0.3255,  0.3412,  0.3412],
         [-0.4510, -0.4039, -0.4431,  ...,  0.2706,  0.2706,  0.2784],
         ...,
         [-0.6157, -0.5686, -0.5765,  ..., -0.5843, -0.5765, -0.5529],
         [-0.5765, -0.5137, -0.5529,  ..., -0.5059, -0.5216, -0.5059],
         [-0.5765, -0.5529, -0.5922,  ..., -0.5373, -0.5608, -0.5529]],

        [[-0.2706, -0.2078, -0.2941,  ...,  0.3333,  0.3333,  0.3176],
         [-0.1843, -0.1922, -0.2078,  ...,  0.3333,  0.3569,  0.3490],
         [-0.1373, -0.0902, -0.1294,  ...,  0.3098,  0.3255,  0.3255],
         ...,
         [-0.4745, -0.4353, -0.4353,  ..., -0.4824, -0.4745, -0.4431],
         [-0.4588, -0.3882, -0.4118,  ..., -0.4118, -0.4118, -0.3961],
         [-0.4745, -0.4431, -0.4588,  ..., -0.4431, -0.4588, -0.4510]],

        [[-0.0039,  0.0588, -0.0118,  ...,  0.3333,  0.3255,  0.3020],
         [ 0.0902,  0.

In [4]:
from transformers import ViTForImageClassification

labels = ["Bicycle"]

model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./results",
  per_device_train_batch_size=16,
  eval_strategy="no",
  num_train_epochs=4,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=100,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  # report_to='tensorboard',
  # load_best_model_at_end=True,
)

In [6]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=processor
)

  trainer = Trainer(


In [7]:
train_results = trainer.train()
trainer.save_model()

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Step,Training Loss
100,0.1054
200,0.09
300,0.0789
400,0.0805
500,0.076
600,0.0781
700,0.077
800,0.0692
900,0.0759
1000,0.0802


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into 