<a href="https://colab.research.google.com/github/EloneSampaio/postgraduate/blob/master/vision_transformers_tensorflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Basic Implementation of Vision Transformers in TensorFlow

In [None]:
from transformers import ViTFeatureExtractor, TFAutoModel, TFAutoModelForImageClassification
from tensorflow.keras.preprocessing import image
import numpy as np
import tensorflow as tf

### Loading the pre-trained ViT model
- Load a pre-trained ViT checkpoint for image classification tasks.

- To obtain embeddings we use TFAutoModel (the base model) rather than a classification model such as TFAutoModelForImageClassification.

- The classification model returns the logits (the classification output) directly, whereas the base model returns the intermediate vectors we want to inspect.

- The base model returns a TFBaseModelOutputWithPooling object, whose last_hidden_state field holds the image's feature vectors (embeddings), one per token.
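To make the shape of last_hidden_state concrete, here is a minimal sketch that computes the expected token count by hand. The numbers (224-pixel input, 16-pixel patches, hidden size 768) are assumptions read off the checkpoint name and the usual "base" configuration; `vit_sequence_length` is a hypothetical helper, not part of transformers. Checking `model.config` is the reliable way to confirm them.

```python
# Assumed values for google/vit-base-patch16-224; confirm with model.config.
image_size = 224   # input resolution in pixels
patch_size = 16    # side length of each square patch
hidden_size = 768  # embedding width of the "base" ViT


def vit_sequence_length(image_size: int, patch_size: int) -> int:
    """Tokens seen by the encoder: one per patch plus the [CLS] token."""
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + 1


batch_size = 1
seq_len = vit_sequence_length(image_size, patch_size)
expected_shape = (batch_size, seq_len, hidden_size)
print(expected_shape)  # (1, 197, 768)
```

Under these assumptions, position 0 along the second axis is the [CLS] token and the remaining 196 positions correspond to the 14 x 14 grid of image patches.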

In [None]:
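# Base model: returns hidden states (embeddings), without a classification head
model = TFAutoModel.from_pretrained('google/vit-base-patch16-224')
# Preprocessor: resizes, rescales and normalizes images for this checkpoint
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
# Classification model: same backbone plus the ImageNet classification head
modelx = TFAutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224')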

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFViTModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing TFViTModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFViTModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
Some weights or buffers of the TF 2.0 model TFViTModel were not initialized from the PyTorch model and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
All PyTorch model weights were used when initializing TFViTForImageClassification.

All the weights of TFViTForImageClassification

In [None]:
# Load and preprocess the image
img_path = '/content/drive/MyDrive/Colab Notebooks/rinocerontes.jpg'

img = image.load_img(img_path, target_size=(224, 224))
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)  # add a batch dimension: (1, 224, 224, 3)
inputs = feature_extractor(images=img, return_tensors="tf")
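
For intuition about what the feature extractor does to an already resized image, the sketch below approximates it with NumPy. It rests on assumptions that should be verified against the checkpoint's preprocessor configuration: pixel values are rescaled by 1/255, each channel is normalized with mean 0.5 and standard deviation 0.5, and the result is returned channels-first. `manual_vit_preprocess` is a hypothetical name, and resizing is skipped entirely.

```python
import numpy as np


def manual_vit_preprocess(batch_hwc: np.ndarray) -> np.ndarray:
    """Approximate the feature extractor on an already resized batch.

    batch_hwc: array of shape (batch, height, width, 3) with values in [0, 255].
    Returns an array of shape (batch, 3, height, width) with values in [-1, 1].
    """
    scaled = batch_hwc.astype(np.float32) / 255.0   # rescale to [0, 1]
    normalized = (scaled - 0.5) / 0.5               # assumed mean=0.5, std=0.5
    return np.transpose(normalized, (0, 3, 1, 2))   # channels-last -> channels-first


# Random stand-in for a real image batch, used only to show the shapes.
dummy = np.random.default_rng(0).uniform(0, 255, size=(1, 224, 224, 3))
pixel_values = manual_vit_preprocess(dummy)
print(pixel_values.shape)  # (1, 3, 224, 224)
```

If these assumptions hold, comparing `pixel_values` with `inputs["pixel_values"].numpy()` on the same image is a quick way to check that the preprocessing behaves as expected.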

In [None]:
# Running inference

# Get the model's intermediate outputs (before classification)
outputs = model(inputs)
outputsx = modelx(inputs)

# "last_hidden_state" holds the feature vectors
last_hidden_state = outputs.last_hidden_state

# The [CLS] vector (position 0) can be used as a representation of the image
cls_token_embedding = last_hidden_state[:, 0, :]

print("[CLS] vector produced by ViT: ", cls_token_embedding.numpy())


[CLS] vector produced by ViT:  [[-4.74241436e-01  1.50001585e+00 -4.59566891e-01  1.23387647e+00
  -6.78003550e-01 -4.62052464e-01  4.13607210e-01 -9.45454955e-01
  -7.79865980e-01 -1.10183254e-01 -1.22581875e+00  1.26881635e+00
  -5.47786891e-01  5.73244512e-01 -1.69266105e-01 -5.66495180e-01
   2.07246244e-01 -2.90045559e-01  8.99169028e-01  9.78285819e-03
  -1.23377538e+00  1.44863164e+00 -6.75288081e-01  2.83313453e-01
   2.35932752e-01  3.06782901e-01  4.00874257e-01  5.08255661e-01
   7.42257714e-01 -1.06808789e-01 -2.63753235e-01 -1.26889479e+00
   2.41408311e-02  8.09072375e-01  4.03772712e-01 -4.35851961e-01
  -1.62648535e+00 -1.64567471e-01 -2.64923394e-01  1.19948700e-01
   8.34790230e-01  4.12109584e-01  3.00727542e-02  4.54543382e-01
   2.40489674e+00 -5.50921559e-01  2.32385933e-01  1.21932197e+00
   1.99685860e+00 -1.08369029e+00  4.52861995e-01  2.00174570e-01
  -5.15188098e-01 -6.73951805e-01 -3.68487090e-04 -1.05395854e-01
   1.56817746e+00  1.17957604e+00 -3.77302319e
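
One common use of a [CLS] vector as an image representation is comparing two images by the cosine similarity of their embeddings. The sketch below illustrates the idea with random stand-in vectors rather than real model outputs; `cosine_similarity` and the `emb_*` arrays are hypothetical names introduced only for this example.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between two embedding vectors (any shape, flattened)."""
    a = a.ravel()
    b = b.ravel()
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


rng = np.random.default_rng(0)
emb_a = rng.normal(size=(1, 768))                  # stand-in for one image's [CLS] vector
emb_b = emb_a + 0.1 * rng.normal(size=(1, 768))    # a slightly perturbed copy
emb_c = rng.normal(size=(1, 768))                  # an unrelated vector

print("a vs b:", cosine_similarity(emb_a, emb_b))  # close to 1
print("a vs c:", cosine_similarity(emb_a, emb_c))  # close to 0
```

With real data, `cls_token_embedding.numpy()` from two different images would take the place of the stand-in vectors.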

In [None]:
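# Run the classification model and pick the highest-scoring class
outputs = modelx(inputs)
logits = outputs.logits  # shape (1, 1000): one score per ImageNet class
predicted_class_idx = tf.argmax(logits, axis=-1).numpy()[0]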

In [None]:
print(f"Predicted class: {modelx.config.id2label[predicted_class_idx]}")

Predicted class: African elephant, Loxodonta africana


In [None]:
print(f"Predicted class index: {predicted_class_idx}")

Predicted class index: 386
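
The argmax alone hides how confident the model is and which classes came close. A minimal sketch of turning logits into probabilities and listing the top candidates is shown below. It uses synthetic logits, not the notebook's actual output, and the `softmax` and `top_k` helpers are hypothetical names written for this example.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


def top_k(probs_1d: np.ndarray, k: int = 5):
    """Return (class_index, probability) pairs, highest probability first."""
    idx = np.argsort(probs_1d)[::-1][:k]
    return [(int(i), float(probs_1d[i])) for i in idx]


# Synthetic logits with an artificial peak at index 386, for illustration only.
rng = np.random.default_rng(0)
fake_logits = rng.normal(size=(1, 1000))
fake_logits[0, 386] = 10.0

probs = softmax(fake_logits)
for class_idx, p in top_k(probs[0], k=5):
    print(class_idx, round(p, 4))
```

In the notebook, `logits.numpy()` would replace `fake_logits`, and each index could be mapped to a label through `modelx.config.id2label`.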
