![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/jupyter/prediction/english/ViTForImageClassification.ipynb)

## ViTForImageClassification Annotator

In this notebook we are going to classify images using Spark NLP.

### Downloading Images

In [4]:
!mkdir -p images

In [5]:
import requests

spark_nlp_repo = "https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp"
image_resources = spark_nlp_repo + "/master/image-classification-vit/src/test/resources/image/"
images = ["bluetick.jpg", "tractor.JPEG", "chihuahua.jpg", "palace.JPEG",
          "hippopotamus.JPEG", "hen.JPEG"]

In [6]:
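for image in images:
    image_uri = image_resources + image
    response = requests.get(image_uri)
    response.raise_for_status()  # stop on a failed download instead of saving an error page
    # Write the downloaded bytes; the context manager closes the file handle
    with open("./images/" + image, 'wb') as f:
        f.write(response.content)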

In [8]:
import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline  # explicit import for the Pipeline used below
from pyspark.sql import SparkSession

In [None]:
spark = sparknlp.start()

In [10]:
data_df = spark.read.format("image") \
            .load(path="images/")

### Pipeline with ViTForImageClassification

In [12]:
image_assembler = ImageAssembler() \
            .setInputCol("image") \
            .setOutputCol("image_assembler")

image_classifier = ViTForImageClassification \
    .pretrained() \
    .setInputCols("image_assembler") \
    .setOutputCol("class")

pipeline = Pipeline(stages=[
    image_assembler,
    image_classifier,
])

image_classifier_vit_base_patch16_224 download started this may take some time.
Approximate size to download 309.7 MB
[OK!]


In [13]:
model = pipeline.fit(data_df)

In [14]:
image_df = model.transform(data_df)
image_df.show(2)

+--------------------+--------------------+--------------------+
|               image|     image_assembler|               class|
+--------------------+--------------------+--------------------+
|[file:///content/...|[[image, file:///...|[[category, 0, 5,...|
|[file:///content/...|[[image, file:///...|[[category, 0, 55...|
+--------------------+--------------------+--------------------+
only showing top 2 rows



### Light Pipeline

To use a `LightPipeline` with the ViT transformer, we need the new method `fullAnnotateImage`, which accepts three kinds of input:
1. The path to a single image
2. A list of image paths
3. The path to a directory of images
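The three input forms can be pictured with the small helper below. `resolve_image_paths` is hypothetical and not part of Spark NLP; it only sketches how a single path, a list of paths, and a directory could all be reduced to one list of image files. The extension filter is an assumption, and `fullAnnotateImage` may handle these cases differently internally.

```python
import os
from typing import List, Union

# Assumed set of image extensions; Spark NLP's own filtering may differ.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}


def resolve_image_paths(source: Union[str, List[str]]) -> List[str]:
    """Hypothetical helper: normalize the three input forms to a list of file paths."""
    if isinstance(source, list):
        # Form 2: an explicit list of image paths is used as-is.
        return list(source)
    if os.path.isdir(source):
        # Form 3: a directory expands to the image files it contains, sorted for determinism.
        return sorted(
            os.path.join(source, name)
            for name in os.listdir(source)
            if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
        )
    # Form 1: a single image path becomes a one-element list.
    return [source]
```

With the files downloaded earlier, `resolve_image_paths("images/")` would return the six image paths in sorted order.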

In [15]:
light_pipeline = LightPipeline(model)
annotations_result = light_pipeline.fullAnnotateImage("images/hippopotamus.JPEG")
annotations_result[0].keys()

dict_keys(['image_assembler', 'class'])

In [16]:
for result in annotations_result:
  print(result['image_assembler'])
  print(result['class'])

[AnnotationImage(image, images/hippopotamus.JPEG, 333, 500, 3, 16, b'\x84~k\x81}e\x81\x7fg|xe\x80}n\x87\x83q...
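The `AnnotationImage` above carries the decoded pixels as raw bytes, next to the image's origin, dimensions, number of channels (3) and mode (16). Assuming the layout of Spark's image data source, where mode 16 is OpenCV's `CV_8UC3` (8 bits per channel, three channels, stored row by row in BGR order), the sketch below reads one pixel back as RGB. `pixel_rgb` is a hypothetical helper for illustration, not a Spark NLP function.

```python
def pixel_rgb(data: bytes, height: int, width: int, n_channels: int, row: int, col: int) -> tuple:
    """Hypothetical helper: return the (R, G, B) value of one pixel.

    Assumes 8-bit channels stored row-major in BGR order, as in Spark's
    image schema for mode 16 (CV_8UC3).
    """
    if n_channels < 3:
        raise ValueError("this sketch expects at least three channels")
    if len(data) != height * width * n_channels:
        raise ValueError("byte length does not match height * width * nChannels")
    if not (0 <= row < height and 0 <= col < width):
        raise IndexError("pixel coordinates out of range")
    # Offset of the first channel of pixel (row, col) in the flat buffer.
    offset = (row * width + col) * n_channels
    # Indexing a bytes object yields ints in the range 0..255.
    blue, green, red = data[offset], data[offset + 1], data[offset + 2]
    return (red, green, blue)
```

Under these assumptions the hippopotamus buffer should hold 333 × 500 × 3 = 499,500 bytes, and its first three bytes, `\x84~k`, would decode to the RGB value (107, 126, 132).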

To annotate several images at once, we just pass a list of image paths:

In [21]:
images = ["images/bluetick.jpg", "images/palace.JPEG", "images/hen.JPEG"]
annotations_result = light_pipeline.fullAnnotateImage(images)
annotations_result[0].keys()

dict_keys(['image_assembler', 'class'])

In [24]:
for result in annotations_result:
  print(result['class'])

[Annotation(category, 0, 7, bluetick, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 1.3846728E-6, Some(beer glass) -> 1.1807944E-6, image -> 0, Some(damselfly) -> 3.6875622E-7, Some(turnstile) -> 2.023695E-6, Some(cockroach, roach) -> 6.2982855E-7, height -> 500, Some(bulbul) -> 5.417509E-7, Some(sea snake) -> 5.7421556E-7, origin -> images/bluetick.jpg, Some(mixing bowl) -> 5.4001305E-7, mode -> 16, None -> 4.5454306E-7, Some(whippet) -> 1.2101438E-6, width -> 333, Some(buckle) -> 1.1306514E-6))]
[Annotation(category, 0, 5, palace, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 6.3918545E-5, Some(beer glass) -> 8.879939E-6, image -> 0, Some(damselfly) -> 9.565577E-6, Some(turnstile) -> 6.315168E-5, Some(cockroach, roach) -> 1.125408E-5, height -> 334, Some(bulbul) -> 3.321073E-5, Some(sea snake) -> 1.0886038E-5, origin -> images/palace.JPEG, Some(mixing bowl) -> 2.6202975E-5, mode -> 16, None -> 2.6134943E-5, Some(whippet) -> 1.3805137E-5, width -> 500, Some(buckle) -> 3.121459E-
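Each `category` annotation holds the predicted label as its result (`bluetick`, `palace`) and a metadata map that mixes image fields (`nChannels`, `height`, `width`, `origin`, `mode`, `image`) with class scores keyed as `Some(label)`, plus a `None` entry. The sketch below assumes the Python `Annotation.metadata` is a plain `dict` of strings with those keys, as the printed maps suggest, and ranks the score entries. `top_scores` is a hypothetical helper, and the sample values are a subset copied from the bluetick output above.

```python
from typing import Dict, List, Tuple

# Non-score fields that the printed metadata maps also contain.
IMAGE_FIELDS = {"nChannels", "height", "width", "origin", "mode", "image"}


def top_scores(metadata: Dict[str, str], k: int = 3) -> List[Tuple[str, float]]:
    """Hypothetical helper: return the k highest class scores from an annotation's metadata."""
    scores = []
    for key, value in metadata.items():
        if key in IMAGE_FIELDS:
            continue  # skip image properties such as height and origin
        # Keys look like "Some(beer glass)"; strip the Option wrapper when present.
        if key.startswith("Some(") and key.endswith(")"):
            label = key[len("Some("):-1]
        else:
            label = key
        scores.append((label, float(value)))
    scores.sort(key=lambda pair: pair[1], reverse=True)
    return scores[:k]


# Subset of the bluetick metadata printed above (values kept as strings).
bluetick_meta = {
    "nChannels": "3", "height": "500", "width": "333", "mode": "16", "image": "0",
    "origin": "images/bluetick.jpg",
    "Some(turnstile)": "2.023695E-6",
    "Some(lumbermill, sawmill)": "1.3846728E-6",
    "Some(whippet)": "1.2101438E-6",
    "Some(buckle)": "1.1306514E-6",
    "None": "4.5454306E-7",
}
print(top_scores(bluetick_meta))
```

In a real run the metadata would come from something like `result['class'][0].metadata`; the image fields are filtered out because they are not scores and, like `origin`, may not even parse as numbers.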

Alternatively, we can pass the path of a directory, and every image it contains will be annotated:

In [19]:
annotations_result = light_pipeline.fullAnnotateImage("images/")
annotations_result[0].keys()

dict_keys(['image_assembler', 'class'])

In [20]:
for result in annotations_result:
  print(result['class'])

[Annotation(category, 0, 8, Chihuahua, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 2.093878E-7, Some(beer glass) -> 2.577504E-7, image -> 0, Some(damselfly) -> 7.065122E-8, Some(turnstile) -> 5.03293E-7, Some(cockroach, roach) -> 4.3412697E-7, height -> 500, Some(bulbul) -> 3.4132438E-7, Some(sea snake) -> 1.5207818E-6, origin -> images/chihuahua.jpg, Some(mixing bowl) -> 2.208622E-7, mode -> 16, None -> 3.1694032E-7, Some(whippet) -> 1.1523372E-5, width -> 333, Some(buckle) -> 3.492674E-6))]
[Annotation(category, 0, 5, palace, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 6.3918545E-5, Some(beer glass) -> 8.879939E-6, image -> 0, Some(damselfly) -> 9.565577E-6, Some(turnstile) -> 6.315168E-5, Some(cockroach, roach) -> 1.125408E-5, height -> 334, Some(bulbul) -> 3.321073E-5, Some(sea snake) -> 1.0886038E-5, origin -> images/palace.JPEG, Some(mixing bowl) -> 2.6202975E-5, mode -> 16, None -> 2.6134943E-5, Some(whippet) -> 1.3805137E-5, width -> 500, Some(buckle) -> 3.121459E-5))