![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/jupyter/prediction/english/ViTForImageClassification.ipynb)

In [None]:
!wget https://setup.johnsnowlabs.com/colab.sh -O - | bash /dev/stdin -p 3.2.1 -s 4.1.0

## ViTForImageClassification Annotator

In this notebook, we classify images with Spark NLP's `ViTForImageClassification` annotator.

### Downloading Images

In [4]:
!mkdir -p images

In [5]:
import requests

spark_nlp_repo = "https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp"
image_resources = spark_nlp_repo + "/master/image-classification-vit/src/test/resources/image/"
images = ["bluetick.jpg", "tractor.JPEG", "chihuahua.jpg", "palace.JPEG",
          "hippopotamus.JPEG", "hen.JPEG"]

In [6]:
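for image in images:
    image_uri = image_resources + image
    response = requests.get(image_uri)
    response.raise_for_status()  # stop early if a download fails
    # Write the raw bytes and close the file handle when done
    with open("./images/" + image, "wb") as f:
        f.write(response.content)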
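A failed request can leave an HTML error page on disk under a `.jpg` name, and the image reader will not be able to decode it. The sketch below checks each file for the JPEG start-of-image signature. It assumes every download is a JPEG, as all of the files here are. `looks_like_jpeg` and `find_non_jpegs` are hypothetical helpers, not part of Spark NLP:

```python
import os

# JPEG files begin with the SOI marker FF D8, followed by the FF of the next marker.
JPEG_MAGIC = b"\xff\xd8\xff"

def looks_like_jpeg(path: str) -> bool:
    """Return True if the file at `path` starts with the JPEG signature (hypothetical helper)."""
    with open(path, "rb") as f:
        return f.read(3) == JPEG_MAGIC

def find_non_jpegs(folder: str) -> list:
    """List the files in `folder` whose first bytes do not match the JPEG signature."""
    return sorted(
        name for name in os.listdir(folder)
        if not looks_like_jpeg(os.path.join(folder, name))
    )
```

When every download succeeded, `find_non_jpegs("images")` should return an empty list. Spark's image data source can also skip undecodable files on its own via `.option("dropInvalid", True)`.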

In [8]:
import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession

In [9]:
spark = sparknlp.start()

Apache Spark version: 3.0.2


In [10]:
data_df = spark.read.format("image") \
            .load(path="images/")
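Spark's image data source returns a single `image` struct column with the fields `origin`, `height`, `width`, `nChannels`, `mode` and `data`. The `data` field stores the decoded pixels row by row in OpenCV's interleaved layout, which is BGR order for three-channel images. As a minimal sketch, assuming that layout with no row padding, one pixel can be located in the flat byte buffer like this (`pixel_offset` and `read_pixel` are illustrative helpers, not Spark APIs):

```python
def pixel_offset(row: int, col: int, channel: int, width: int, n_channels: int) -> int:
    """Byte offset of one channel value in a row-major, interleaved image buffer."""
    return (row * width + col) * n_channels + channel

def read_pixel(data: bytes, row: int, col: int, width: int, n_channels: int) -> tuple:
    """Return all channel values of one pixel (BGR order for 3-channel Spark images)."""
    start = pixel_offset(row, col, 0, width, n_channels)
    return tuple(data[start:start + n_channels])
```

On a collected row, for example `r = data_df.first().image`, the call `read_pixel(r.data, 0, 0, r.width, r.nChannels)` should give the top-left pixel as `(blue, green, red)`.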

### Pipeline with ViTForImageClassification

In [12]:
image_assembler = ImageAssembler() \
            .setInputCol("image") \
            .setOutputCol("image_assembler")

image_classifier = ViTForImageClassification \
    .pretrained() \
    .setInputCols("image_assembler") \
    .setOutputCol("class")

pipeline = Pipeline(stages=[
    image_assembler,
    image_classifier,
])

image_classifier_vit_base_patch16_224 download started this may take some time.
Approximate size to download 309.7 MB
[OK!]
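The downloaded model, `image_classifier_vit_base_patch16_224`, follows the usual ViT naming: a ViT-Base encoder that takes 224×224 inputs and cuts them into 16×16-pixel patches. The arithmetic below assumes images are resized to 224×224 before they reach the encoder, as is standard for this checkpoint, and that a single `[CLS]` token is prepended as in the original ViT design. It shows how many tokens the transformer sees per image:

```python
def vit_patch_stats(image_size: int = 224, patch_size: int = 16, channels: int = 3) -> dict:
    """Count patch tokens and flattened patch length for a square ViT input."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return {
        "patches_per_side": per_side,                            # 224 // 16 = 14
        "num_patches": per_side * per_side,                      # 14 * 14 = 196
        "patch_vector_len": patch_size * patch_size * channels,  # 16 * 16 * 3 = 768
        "sequence_len": per_side * per_side + 1,                 # 196 patches + [CLS] = 197
    }

print(vit_patch_stats())
```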


In [13]:
model = pipeline.fit(data_df)

In [14]:
image_df = model.transform(data_df)
image_df.show(2)

+--------------------+--------------------+--------------------+
|               image|     image_assembler|               class|
+--------------------+--------------------+--------------------+
|[file:///content/...|[[image, file:///...|[[category, 0, 5,...|
|[file:///content/...|[[image, file:///...|[[category, 0, 55...|
+--------------------+--------------------+--------------------+
only showing top 2 rows



### Light Pipeline

To use a `LightPipeline` with the ViT annotator, call the new `fullAnnotateImage` method, which accepts three kinds of input:
1. A path to a single image
2. A list of image paths
3. A path to a directory of images
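Conceptually, all three input kinds reduce to a list of image files. The sketch below illustrates that idea; it is not Spark NLP's implementation. `expand_image_inputs` and the extension set are assumptions, and the sketch sorts directory entries for determinism, whereas `fullAnnotateImage` may return results in a different order:

```python
import os
from typing import List, Union

# Hypothetical set of extensions treated as images (compared case-insensitively).
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}

def expand_image_inputs(inputs: Union[str, List[str]]) -> List[str]:
    """Turn a file path, a list of file paths, or a directory path into a list of file paths."""
    if isinstance(inputs, list):
        return list(inputs)  # already a list of image paths
    if os.path.isdir(inputs):
        return sorted(
            os.path.join(inputs, name)
            for name in os.listdir(inputs)
            if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
        )
    return [inputs]  # a single image path
```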

In [15]:
light_pipeline = LightPipeline(model)
annotations_result = light_pipeline.fullAnnotateImage("images/hippopotamus.JPEG")
annotations_result[0].keys()

dict_keys(['image_assembler', 'class'])

In [16]:
for result in annotations_result:
  image_assembler = result['image_assembler'][0]
  print(f"annotator_type: {image_assembler.annotator_type}")
  print(f"origin: {image_assembler.origin}")
  print(f"height: {image_assembler.height}")
  print(f"width: {image_assembler.width}")
  print(f"nChannels: {image_assembler.nChannels}")
  print(f"mode: {image_assembler.mode}")
  print(f"result size: {str(len(image_assembler.result))}")
  print(f"metadata: {image_assembler.metadata}")
  print(result['class'])

annotator_type: image
origin: images/hippopotamus.JPEG
height: 333
width: 500
nChannels: 3
mode: 16
result size: 499500
metadata: Map()
[Annotation(category, 0, 55, hippopotamus, hippo, river horse, Hippopotamus amphibius, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 7.2882756E-8, Some(beer glass) -> 9.0488925E-8, image -> 0, Some(damselfly) -> 1.9379786E-7, Some(turnstile) -> 6.8434524E-8, Some(cockroach, roach) -> 1.6622849E-7, height -> 333, Some(bulbul) -> 1.6930231E-7, Some(sea snake) -> 8.89582E-8, origin -> images/hippopotamus.JPEG, Some(mixing bowl) -> 1.2995402E-7, mode -> 16, None -> 1.3814622E-7, Some(whippet) -> 3.894023E-8, width -> 500, Some(buckle) -> 1.0061492E-7))]
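The `image_assembler` fields printed above are consistent with one byte per channel per pixel. The `result` size (499500) equals height × width × nChannels for this 333×500 RGB image. `mode: 16` matches OpenCV's `CV_8UC3` type code (8-bit unsigned, three channels), which is computed as the depth code plus `(channels - 1) << 3`. The check below uses only the numbers from that output. `opencv_type` is an illustrative re-implementation of OpenCV's `CV_MAKETYPE` macro:

```python
# Values copied from the printed output above (hippopotamus.JPEG).
height, width, n_channels = 333, 500, 3
result_size = 499500

# One byte per channel per pixel for an 8-bit image.
assert height * width * n_channels == result_size

def opencv_type(depth_code: int, channels: int) -> int:
    """OpenCV's CV_MAKETYPE: depth code in the low 3 bits, (channels - 1) in the bits above."""
    return depth_code + ((channels - 1) << 3)

CV_8U = 0  # OpenCV depth code for 8-bit unsigned
print(opencv_type(CV_8U, 3))  # 16, i.e. CV_8UC3
```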


To annotate several images at once, pass a list of image paths:

In [17]:
images = ["images/bluetick.jpg", "images/palace.JPEG", "images/hen.JPEG"]
annotations_result = light_pipeline.fullAnnotateImage(images)
annotations_result[0].keys()

dict_keys(['image_assembler', 'class'])

In [18]:
for result in annotations_result:
  print(result['class'])

[Annotation(category, 0, 7, bluetick, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 1.3846728E-6, Some(beer glass) -> 1.1807944E-6, image -> 0, Some(damselfly) -> 3.6875622E-7, Some(turnstile) -> 2.023695E-6, Some(cockroach, roach) -> 6.2982855E-7, height -> 500, Some(bulbul) -> 5.417509E-7, Some(sea snake) -> 5.7421556E-7, origin -> images/bluetick.jpg, Some(mixing bowl) -> 5.4001305E-7, mode -> 16, None -> 4.5454306E-7, Some(whippet) -> 1.2101438E-6, width -> 333, Some(buckle) -> 1.1306514E-6))]
[Annotation(category, 0, 5, palace, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 6.3918545E-5, Some(beer glass) -> 8.879939E-6, image -> 0, Some(damselfly) -> 9.565577E-6, Some(turnstile) -> 6.315168E-5, Some(cockroach, roach) -> 1.125408E-5, height -> 334, Some(bulbul) -> 3.321073E-5, Some(sea snake) -> 1.0886038E-5, origin -> images/palace.JPEG, Some(mixing bowl) -> 2.6202975E-5, mode -> 16, None -> 2.6134943E-5, Some(whippet) -> 1.3805137E-5, width -> 500, Some(buckle) -> 3.121459E-
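With several images, it is handy to key each prediction by its file. Each element of `annotations_result` is a dict whose `class` entry holds annotations. In the output above, every annotation carries the predicted label in `result` and the file path under the `origin` metadata key. The sketch below relies on those two attributes. `FakeAnnotation` is a hypothetical stand-in that lets the example run without Spark, and `labels_by_origin` is an illustrative helper. If an image produced more than one annotation, the last one would win:

```python
from collections import namedtuple

# Hypothetical stand-in for Spark NLP's Annotation; only `result` and `metadata` are used.
FakeAnnotation = namedtuple("FakeAnnotation", ["result", "metadata"])

def labels_by_origin(annotations_result, output_col="class"):
    """Map each image path (metadata 'origin') to the predicted label of `output_col`."""
    mapping = {}
    for result in annotations_result:
        for annotation in result[output_col]:
            mapping[annotation.metadata.get("origin")] = annotation.result
    return mapping

# Stand-in data mirroring two of the predictions printed above.
sample = [
    {"class": [FakeAnnotation("bluetick", {"origin": "images/bluetick.jpg"})]},
    {"class": [FakeAnnotation("palace", {"origin": "images/palace.JPEG"})]},
]
print(labels_by_origin(sample))
```

Applied to the real `annotations_result`, this should map, for example, `"images/bluetick.jpg"` to `"bluetick"`.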

Alternatively, pass the path of a directory that contains all the images:

In [19]:
annotations_result = light_pipeline.fullAnnotateImage("images/")
annotations_result[0].keys()

dict_keys(['image_assembler', 'class'])

In [20]:
for result in annotations_result:
  print(result['class'])

[Annotation(category, 0, 8, Chihuahua, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 2.093878E-7, Some(beer glass) -> 2.577504E-7, image -> 0, Some(damselfly) -> 7.065122E-8, Some(turnstile) -> 5.03293E-7, Some(cockroach, roach) -> 4.3412697E-7, height -> 500, Some(bulbul) -> 3.4132438E-7, Some(sea snake) -> 1.5207818E-6, origin -> images/chihuahua.jpg, Some(mixing bowl) -> 2.208622E-7, mode -> 16, None -> 3.1694032E-7, Some(whippet) -> 1.1523372E-5, width -> 333, Some(buckle) -> 3.492674E-6))]
[Annotation(category, 0, 55, hippopotamus, hippo, river horse, Hippopotamus amphibius, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 7.2882756E-8, Some(beer glass) -> 9.0488925E-8, image -> 0, Some(damselfly) -> 1.9379786E-7, Some(turnstile) -> 6.8434524E-8, Some(cockroach, roach) -> 1.6622849E-7, height -> 333, Some(bulbul) -> 1.6930231E-7, Some(sea snake) -> 8.89582E-8, origin -> images/hippopotamus.JPEG, Some(mixing bowl) -> 1.2995402E-7, mode -> 16, None -> 1.3814622E-7, Some(whippet) 
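Besides image fields (`origin`, `height`, `width`, `nChannels`, `mode`, `image`), the printed metadata holds per-label scores stored as strings. The sketch below ranks whatever score entries are present. In the outputs above, only a handful of labels appear, with keys such as `Some(whippet)` and `None`, and the predicted label is not among them. So this is a way to inspect the emitted scores, not to recover a full top-k ranking. `top_scores` is a hypothetical helper, and the sample dict copies a subset of the `chihuahua.jpg` values printed above:

```python
# Metadata keys that describe the image rather than a label score
# (taken from the printed output above).
IMAGE_FIELDS = {"nChannels", "image", "height", "width", "origin", "mode"}

def top_scores(metadata: dict, k: int = 3) -> list:
    """Return the k highest-scoring (label, score) pairs found in a metadata dict."""
    scores = [
        (key, float(value))  # metadata values are strings, e.g. "1.1523372E-5"
        for key, value in metadata.items()
        if key not in IMAGE_FIELDS
    ]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)[:k]

# Subset of the chihuahua.jpg metadata printed above.
chihuahua_metadata = {
    "origin": "images/chihuahua.jpg", "height": "500", "width": "333",
    "nChannels": "3", "mode": "16", "image": "0",
    "Some(whippet)": "1.1523372E-5", "Some(buckle)": "3.492674E-6",
    "Some(sea snake)": "1.5207818E-6", "Some(turnstile)": "5.03293E-7",
}
print(top_scores(chihuahua_metadata))
```

On a real result, the same call would be `top_scores(result['class'][0].metadata)`, assuming `metadata` is a plain string-to-string dict as in Spark NLP's Python `Annotation`.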