#### Vision Transformer with randomly initialized weights

In [3]:
from transformers import ViTConfig, ViTForImageClassification

# This will instantiate a vision transformer with randomly initialized weights
config = ViTConfig(num_hidden_layers=12, hidden_size=768)
vit_model = ViTForImageClassification(config)

In [4]:
print(config)

ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.27.4"
}
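
The printed values fix a few derived quantities that are worth checking by hand. The following is a minimal sketch, assuming the usual ViT layout in which the image is cut into non-overlapping square patches and one classification token is prepended to the patch sequence. The variable names are illustrative and are not attributes of the library.

```python
# Values copied from the printed ViTConfig above.
image_size = 224
patch_size = 16
hidden_size = 768
num_attention_heads = 12
num_channels = 3

# Assumption: non-overlapping square patches plus one classification token.
patches_per_side = image_size // patch_size
num_patches = patches_per_side ** 2
sequence_length = num_patches + 1

# Each attention head works on an equal slice of the hidden dimension.
head_dim = hidden_size // num_attention_heads

# Number of raw pixel values in one flattened patch.
patch_dim = num_channels * patch_size ** 2

print(num_patches, sequence_length, head_dim, patch_dim)
# prints: 196 197 64 768
```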



#### Vision Transformer with pretrained weights | Torch [default]

In [6]:
from transformers import ViTForImageClassification
# Classes without the TF prefix use the PyTorch backend
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

#### Vision Transformer with pretrained weights | TensorFlow

In [1]:
from transformers import TFViTForImageClassification

vit_model = TFViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

All model checkpoint layers were used when initializing TFViTForImageClassification.

All the layers of TFViTForImageClassification were initialized from the model checkpoint at google/vit-base-patch16-224.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFViTForImageClassification for predictions without further training.


#### Feature Extraction | Torch

In [3]:
from transformers import ViTFeatureExtractor

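# Instantiate a feature extractor with the default preprocessing settings
feature_extract = ViTFeatureExtractor()

# Or load the preprocessing settings that match a checkpoint on the Hub
feature_extract = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')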

#### AutoAPI

* The Auto Classes automatically instantiate the appropriate class for you, based on the checkpoint you provide.

In [6]:
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

feature_extract = AutoFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
vit_model = AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224')

In [24]:
# TENSORFLOW

from transformers import AutoFeatureExtractor, TFAutoModelForImageClassification

feature_extract = AutoFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
vit_model = TFAutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224')

All model checkpoint layers were used when initializing TFViTForImageClassification.

All the layers of TFViTForImageClassification were initialized from the model checkpoint at google/vit-base-patch16-224.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFViTForImageClassification for predictions without further training.
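
Conceptually, an Auto class reads the `model_type` field from the checkpoint's configuration (`"vit"` in the config printed earlier) and looks up the matching class. The toy sketch below illustrates that dispatch idea only. `ToyViTModel`, `ToyBertModel`, `MODEL_REGISTRY`, and `auto_model_from_config` are hypothetical names and do not reflect the library's actual implementation.

```python
class ToyViTModel:
    """Hypothetical stand-in for a ViT model class."""


class ToyBertModel:
    """Hypothetical stand-in for a BERT model class."""


# Hypothetical registry mapping a config's model_type to a class.
MODEL_REGISTRY = {"vit": ToyViTModel, "bert": ToyBertModel}


def auto_model_from_config(config):
    """Pick and instantiate the class registered for config['model_type']."""
    model_type = config["model_type"]
    try:
        model_class = MODEL_REGISTRY[model_type]
    except KeyError:
        raise ValueError(f"Unrecognized model_type: {model_type!r}") from None
    return model_class()


model = auto_model_from_config({"model_type": "vit"})
print(type(model).__name__)
# prints: ToyViTModel
```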


## Image Classification

In [8]:
url = 'https://images.unsplash.com/photo-1596854407944-bf87f6fdd49e?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=580&q=80'

In [45]:
import requests
from PIL import Image

cat = Image.open(requests.get(url, stream=True).raw)
# cat.save('./cache/cat.jpg')

<img src='./cache/cat.jpg'>

#### Prepare input for the model

In [19]:
# The keyword is `return_tensors` (plural); 'np' returns a NumPy batch of shape (1, 3, 224, 224)
input_feature = feature_extract(cat, return_tensors='np')
input_feature.pixel_values[0].shape

(3, 224, 224)

In [31]:
import numpy as np

np.array(input_feature.pixel_values).shape

(1, 3, 224, 224)
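
To see where the `(1, 3, 224, 224)` layout comes from, here is a rough pure-Python sketch of the per-pixel work a feature extractor does after resizing: rescale to `[0, 1]`, normalize, move channels first, and add a batch dimension. The `preprocess` helper and the toy image are hypothetical, resizing is left out, and the mean and standard deviation of 0.5 are assumptions rather than values read from the checkpoint.

```python
def preprocess(image_hwc, mean=0.5, std=0.5):
    """Turn an (H, W, C) nested list of 0..255 values into a (1, C, H, W) nested list.

    mean and std are assumed values, not read from any checkpoint.
    """
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    # Rescale to [0, 1], normalize, and reorder to channels-first.
    chw = [
        [
            [(image_hwc[y][x][c] / 255.0 - mean) / std for x in range(width)]
            for y in range(height)
        ]
        for c in range(channels)
    ]
    return [chw]  # leading batch dimension of size 1


# Hypothetical 2x2 RGB image standing in for the resized photo.
toy_image = [
    [[0, 128, 255], [255, 255, 255]],
    [[0, 0, 0], [64, 64, 64]],
]

batch = preprocess(toy_image)
shape = (len(batch), len(batch[0]), len(batch[0][0]), len(batch[0][0][0]))
print(shape)
# prints: (1, 3, 2, 2)
```

With a real 224x224 input the same steps yield the `(1, 3, 224, 224)` shape shown above.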

#### Run the model and read the outputs

In [34]:
outputs = vit_model(np.array(input_feature.pixel_values))

logits = outputs.logits

In [37]:
np.argmax(logits, axis=-1)

array([285], dtype=int64)

In [38]:
print('PREDICTION', vit_model.config.id2label[285])

PREDICTION Egyptian cat
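
`argmax` returns only the single best class. To rank several candidates with probabilities, the logits can be passed through a softmax first. The sketch below uses made-up logits and a made-up four-entry label map (`toy_logits`, `toy_id2label`) instead of the model's real 1000-class output, and the `softmax` and `top_k` helpers are illustrative, not library functions.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    largest = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, id2label, k=3):
    """Return the k most probable (label, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# Hypothetical values for illustration only.
toy_logits = [2.0, 0.5, 1.7, -1.0]
toy_id2label = {0: "Egyptian cat", 1: "Persian cat", 2: "tiger cat", 3: "lynx"}

results = top_k(toy_logits, toy_id2label, k=2)
for label, prob in results:
    print(f"{label}: {prob:.3f}")
```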


## Image Classification Pipeline

* `pipeline()` is the simplest way to run inference with a pretrained model: it bundles preprocessing, the forward pass, and postprocessing into a single call.

In [40]:
from transformers import pipeline

classifier = pipeline(task='image-classification', model='google/vit-base-patch16-224')

In [41]:
response = classifier(cat)

In [43]:
for res in response:
    print(res)

{'score': 0.4134212136268616, 'label': 'Egyptian cat'}
{'score': 0.27957770228385925, 'label': 'tiger cat'}
{'score': 0.2714625298976898, 'label': 'tabby, tabby cat'}
{'score': 0.011246128007769585, 'label': 'lynx, catamount'}
{'score': 0.0019182711839675903, 'label': 'Persian cat'}
