![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/image/ViTForImageClassification.ipynb)

In [None]:
!wget https://setup.johnsnowlabs.com/colab.sh -O - | bash /dev/stdin -p 3.2.1 -s 4.1.0

## ViTForImageClassification Annotator

In this notebook we are going to classify images using Spark NLP.

### Downloading Images

In [None]:
!wget -q https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/images/images.zip

In [None]:
import shutil
shutil.unpack_archive("images.zip", "images", "zip")

### Start Spark Session

In [None]:
import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession

In [None]:
spark = sparknlp.start()

In [None]:
# Load the images with Spark's image data source; dropInvalid skips files that cannot be decoded
data_df = spark.read.format("image").option("dropInvalid", value=True).load(path="images/images/")

### Pipeline with ViTForImageClassification

In [None]:
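# Convert Spark's "image" column into Spark NLP image annotations
image_assembler = ImageAssembler() \
    .setInputCol("image") \
    .setOutputCol("image_assembler")

# Pretrained ViT classifier (the default model is downloaded on first use)
image_classifier = ViTForImageClassification \
    .pretrained() \
    .setInputCols("image_assembler") \
    .setOutputCol("class")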

pipeline = Pipeline(stages=[
    image_assembler,
    image_classifier,
])

image_classifier_vit_base_patch16_224 download started this may take some time.
Approximate size to download 309.7 MB
[OK!]
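The downloaded model name, `image_classifier_vit_base_patch16_224`, suggests a ViT-Base model that works on 224x224 inputs split into 16x16 patches. The sketch below computes the resulting sequence length, assuming square, non-overlapping patches plus one prepended classification token as in the original ViT design; `vit_sequence_length` is a hypothetical helper for illustration only.

```python
def vit_sequence_length(image_size: int, patch_size: int, cls_token: bool = True) -> int:
    """Number of tokens a ViT encoder sees for a square image.

    Assumes non-overlapping square patches and (optionally) one extra
    classification token, as in the original ViT design.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side * patches_per_side
    return num_patches + (1 if cls_token else 0)


# vit_base_patch16_224: 224 / 16 = 14 patches per side, 14 * 14 = 196 patches
print(vit_sequence_length(224, 16))                   # 197
print(vit_sequence_length(224, 16, cls_token=False))  # 196
```

Since the images in this notebook come in different sizes (the outputs report heights and widths such as 500 and 333), they presumably have to be resized to 224x224 before being split into patches.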


In [None]:
model = pipeline.fit(data_df)

In [None]:
image_df = model.transform(data_df)
image_df.show()

+--------------------+--------------------+--------------------+
|               image|     image_assembler|               class|
+--------------------+--------------------+--------------------+
|{file:///content/...|[{image, file:///...|[{category, 0, 5,...|
|{file:///content/...|[{image, file:///...|[{category, 0, 11...|
|{file:///content/...|[{image, file:///...|[{category, 0, 55...|
|{file:///content/...|[{image, file:///...|[{category, 0, 2,...|
|{file:///content/...|[{image, file:///...|[{category, 0, 24...|
|{file:///content/...|[{image, file:///...|[{category, 0, 14...|
|{file:///content/...|[{image, file:///...|[{category, 0, 7,...|
|{file:///content/...|[{image, file:///...|[{category, 0, 8,...|
|{file:///content/...|[{image, file:///...|[{category, 0, 6,...|
|{file:///content/...|[{image, file:///...|[{category, 0, 1,...|
+--------------------+--------------------+--------------------+



### Light Pipeline

To use a light pipeline with the ViT transformer, we need the new method `fullAnnotateImage`, which accepts two kinds of input:
1. A path to a single image
2. A list of paths to images
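Accepting either a single path or a list usually comes down to normalizing the argument before processing. A minimal sketch of that normalization step; `normalize_image_input` is a hypothetical helper, not part of the Spark NLP API:

```python
from typing import List, Sequence, Union


def normalize_image_input(paths: Union[str, Sequence[str]]) -> List[str]:
    """Turn a single image path or a sequence of paths into a list of paths.

    Hypothetical helper mirroring the two input kinds listed above.
    """
    if isinstance(paths, str):
        return [paths]  # a single path becomes a one-element list
    result = list(paths)
    if not all(isinstance(p, str) for p in result):
        raise TypeError("every image path must be a string")
    return result


print(normalize_image_input("images/images/hippopotamus.JPEG"))
print(normalize_image_input(["images/images/bluetick.jpg", "images/images/hen.JPEG"]))
```

Either way the caller ends up with a list, which matches how `fullAnnotateImage` returns a list of results even for a single image (hence `annotations_result[0]` in the next cell).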

In [None]:
light_pipeline = LightPipeline(model)
annotations_result = light_pipeline.fullAnnotateImage("images/images/hippopotamus.JPEG")
annotations_result[0].keys()

dict_keys(['image_assembler', 'class'])
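The `image_assembler` entry is an image annotation whose `result` holds the raw pixel bytes. Spark's image data source stores pixels row by row in OpenCV-compatible order (typically BGR), and a `mode` of 16 corresponds to OpenCV's `CV_8UC3`, i.e. three 8-bit channels. Assuming that layout carries over to the annotation, its byte size should equal `height * width * nChannels`, and a single pixel can be read as in this sketch (`expected_byte_size` and `pixel_at` are hypothetical helpers):

```python
from typing import Sequence, Tuple

CV_8UC3 = 16  # OpenCV type code for 8-bit, 3-channel images


def expected_byte_size(height: int, width: int, n_channels: int) -> int:
    """Byte count of an 8-bit image stored row-major with interleaved channels."""
    return height * width * n_channels


def pixel_at(data: Sequence[int], width: int, n_channels: int,
             row: int, col: int) -> Tuple[int, ...]:
    """Return the channel values (e.g. B, G, R) of the pixel at (row, col).

    Assumes row-major storage with interleaved channels, one byte each.
    """
    offset = (row * width + col) * n_channels
    return tuple(data[offset:offset + n_channels])


# A made-up 2x2 BGR image: blue, green / red, white
demo = bytes([255, 0, 0,   0, 255, 0,
              0, 0, 255,   255, 255, 255])
assert len(demo) == expected_byte_size(2, 2, 3)
print(pixel_at(demo, width=2, n_channels=3, row=1, col=0))  # (0, 0, 255), red in BGR
```

Under these assumptions, the 500x333 bluetick image used later would occupy 500 * 333 * 3 = 499,500 bytes.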

In [None]:
for result in annotations_result:
  image_assembler = result['image_assembler'][0]
  print(f"annotator_type: {image_assembler.annotator_type}")
  print(f"origin: {image_assembler.origin}")
  print(f"height: {image_assembler.height}")
  print(f"width: {image_assembler.width}")
  print(f"nChannels: {image_assembler.nChannels}")
  print(f"mode: {image_assembler.mode}")
  print(f"result size: {str(len(image_assembler.result))}")
  print(f"metadata: {image_assembler.metadata}")
  print(result['class'])
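Besides the predicted label in `result`, the `class` annotation's metadata mixes image fields (`nChannels`, `height`, `width`, `mode`, `origin`, `image`) with per-label scores. A sketch for ranking those scores, assuming the metadata arrives as a dict of strings whose score keys use the `Some(label)` form seen in this notebook's list output; `top_labels` is a hypothetical helper:

```python
from typing import Dict, List, Tuple

# Keys that describe the image rather than a class score (taken from the notebook output)
NON_SCORE_KEYS = {"nChannels", "height", "width", "mode", "origin", "image"}


def top_labels(metadata: Dict[str, str], k: int = 3) -> List[Tuple[str, float]]:
    """Return the k highest-scoring labels from a classification metadata dict."""
    scores = []
    for key, value in metadata.items():
        if key in NON_SCORE_KEYS:
            continue
        try:
            score = float(value)
        except ValueError:
            continue  # skip anything that is not a numeric score
        # Strip a Scala-style "Some(...)" wrapper if present
        label = key[5:-1] if key.startswith("Some(") and key.endswith(")") else key
        scores.append((label, score))
    scores.sort(key=lambda item: item[1], reverse=True)
    return scores[:k]


# A subset of the values printed for bluetick.jpg in this notebook
sample = {
    "nChannels": "3", "height": "500", "width": "333", "mode": "16",
    "origin": "images/images/bluetick.jpg", "image": "0",
    "Some(turnstile)": "2.023695E-6",
    "Some(lumbermill, sawmill)": "1.3846728E-6",
    "Some(whippet)": "1.2101438E-6",
    "Some(beer glass)": "1.1807944E-6",
}
print(top_labels(sample, k=2))
```

In the outputs shown, the predicted label itself (e.g. `bluetick`) is reported in `result` and does not appear among the printed metadata keys, so this ranking only covers whichever labels the metadata includes.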

To annotate several images at once, we just define a list of image paths:

In [None]:
images = ["images/images/bluetick.jpg", "images/images/palace.JPEG", "images/images/hen.JPEG"]
annotations_result = light_pipeline.fullAnnotateImage(images)
annotations_result[0].keys()

dict_keys(['image_assembler', 'class'])
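Rather than typing the paths by hand, the list can be built from the extracted folder. A sketch using `pathlib`, assuming the images keep extensions like the ones seen in this notebook (`.jpg`, `.JPEG`); `collect_image_paths` is a hypothetical helper:

```python
from pathlib import Path
from typing import Iterable, List


def collect_image_paths(folder: str,
                        extensions: Iterable[str] = (".jpg", ".jpeg", ".png")) -> List[str]:
    """Return sorted paths of files in `folder` whose extension matches (case-insensitive)."""
    wanted = {ext.lower() for ext in extensions}
    return sorted(
        str(path) for path in Path(folder).iterdir()
        if path.is_file() and path.suffix.lower() in wanted
    )


# Possible usage with the light pipeline defined above:
#   images = collect_image_paths("images/images")
#   annotations_result = light_pipeline.fullAnnotateImage(images)
```

Sorting keeps the order of `annotations_result` reproducible across runs, since directory listing order is not guaranteed.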

In [None]:
for result in annotations_result:
  print(result['class'])

[Annotation(category, 0, 7, bluetick, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 1.3846728E-6, Some(beer glass) -> 1.1807944E-6, image -> 0, Some(damselfly) -> 3.6875622E-7, Some(turnstile) -> 2.023695E-6, Some(cockroach, roach) -> 6.2982855E-7, height -> 500, Some(bulbul) -> 5.417509E-7, Some(sea snake) -> 5.7421556E-7, origin -> images/images/bluetick.jpg, Some(mixing bowl) -> 5.4001305E-7, mode -> 16, None -> 4.5454306E-7, Some(whippet) -> 1.2101438E-6, width -> 333, Some(buckle) -> 1.1306514E-6))]
[Annotation(category, 0, 5, palace, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 6.3918545E-5, Some(beer glass) -> 8.879939E-6, image -> 0, Some(damselfly) -> 9.565577E-6, Some(turnstile) -> 6.315168E-5, Some(cockroach, roach) -> 1.125408E-5, height -> 334, Some(bulbul) -> 3.321073E-5, Some(sea snake) -> 1.0886038E-5, origin -> images/images/palace.JPEG, Some(mixing bowl) -> 2.6202975E-5, mode -> 16, None -> 2.6134943E-5, Some(whippet) -> 1.3805137E-5, width -> 500, Some(buckle)