In [5]:
import imageio.v2 as iio
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import tensorflow.keras as ks
from transformers import ViTImageProcessor, TFViTForImageClassification

In [2]:
# Create an instance of ViTImageProcessor and use ViT base model as pretrained model
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
# Create an instance of ViTForImageClassification and use ViT base model as pretrained model
model = TFViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.summary(expand_nested=True)

All PyTorch model weights were used when initializing TFViTForImageClassification.

All the weights of TFViTForImageClassification were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFViTForImageClassification for predictions without further training.


Model: "tf_vi_t_for_image_classification"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 vit (TFViTMainLayer)        multiple                  85798656  
                                                                 
 classifier (Dense)          multiple                  769000    
                                                                 
Total params: 86567656 (330.23 MB)
Trainable params: 86567656 (330.23 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [36]:
def load_image_data(path):
    '''Function will load image data and process it with tensorflow decode_image function.
    Parameters:
    path <str>: file path string to load the image from

    Return:
    <tf.Tensor>: Decoded image with shape (height, width, channels)
    '''
    image = tf.io.read_file(path)
    return tf.io.decode_image(image, channels=3)

@tf.py_function(Tout=tf.float32)
def apply_feature_extractor(x):
    '''Function will apply "feature_extractor" to image data. Decorator allows usage of this function inside tf.data.Dataset.map().
    Parameters:
    x <tf.Tensor>: image data

    Return:
    <tf.Tensor>: features of image data with shape (channels, heigth, width)
    '''
    return feature_extractor.preprocess(images=x, return_tensors='tf')['pixel_values']

os.chdir(os.path.dirname(os.path.abspath(__file__)))
sub_folders = ['/data/Lego', '/data/Duplo']
abs_folders = []

for i, val in enumerate(sub_folders):
    os.chdir('..')
    print(os.getcwd())
    os.chdir(val)
    abs_folders.append(os.getcwd())
## change back to 


# Create tf.data.Dataset containing all ".jpg" file names in folder
image_pipeline = tf.data.Dataset.list_files('*.jpg')
# Adding labels 
image_pipeline = image_pipeline.map(lambda x: (x, 'cat'))
# Load the image date from file names
image_pipeline = image_pipeline.map(lambda x, y: (load_image_data(x), y))

# Wrap call of ViTImageProcessor in batch() / unbatch(), because otherwise it will add a dimension, that is not
# recognized as a batching dimension by the tf.data.Dataset instance.
image_pipeline = image_pipeline.batch(2)
# Apply feature_extractor to all image data, data in pipeline will be "channel dimension first"
image_pipeline = image_pipeline.map(lambda x, y: (apply_feature_extractor(x), y))
image_pipeline = image_pipeline.unbatch()


for i, val in enumerate(image_pipeline):
    print(tf.TensorSpec.from_tensor(val[1]))
image_pipeline = tf.data.Dataset.from_generator(
    image_pipeline.__iter__, output_signature=(
        tf.TensorSpec(shape=(3, 224, 224), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.string)
    )
)
# Batching 
image_pipeline = image_pipeline.batch(3)

NameError: name '__file__' is not defined

In [31]:
prediction = model.predict(image_pipeline)
tf.argmax(prediction.logits, axis=1)



<tf.Tensor: shape=(3,), dtype=int64, numpy=array([285, 285, 282])>

In [26]:
# Create tf.data.Dataset containing all ".jpg" file names in folder
test = tf.data.Dataset.list_files('*.jpg')
# Load the image date from file names
test = test.map(lambda x: load_image_data(x))

for i, val in enumerate(test):
    print(val.shape)

(224, 224, 3)
(224, 224, 3)
(224, 224, 3)


In [26]:
for i, val in enumerate(image_pipeline.unbatch()):
    print(val[0].numpy().shape)
    print(val[0].set_shape([224, 224, 3]).shape)
    #plt.imshow(val[0])

(3, 224, 224)


ValueError: Tensor's shape (3, 224, 224) is not compatible with supplied shape [224, 224, 3].

In [75]:
# ToDo : Learn about shapes and that shapes 

<tensorflow.python.data.ops.options.Options at 0x7f5106db8280>

In [11]:
pic = tf.io.read_file('sample.jpg')
pic = tf.io.decode_image(pic)
pic_features = feature_extractor(pic, return_tensors='tf')
prediction = model.predict(pic_features['pixel_values'])
pic_features.keys()



dict_keys(['pixel_values'])

In [7]:
type([])

list