![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/image/VisionEncoderDecoderForImageCaptioning.ipynb)

In [None]:
! wget -q http://setup.johnsnowlabs.com/colab.sh -O - | bash

## VisionEncoderDecoderForImageCaptioning Annotator

In this notebook we are going to generate captions for images using Spark NLP. The annotator uses a Vision Transformer (ViT) to encode the images and GPT-2 to generate the caption tokens. This model is rather heavy, so make sure you have enough RAM and, if possible, use an accelerator such as a GPU.
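
To make the encoder-decoder idea concrete, here is a toy sketch of greedy decoding in plain Python. It is not Spark NLP's implementation: `encode`, `next_token_scores` and the tiny fixed caption are hypothetical stand-ins for ViT and GPT-2. The encoder runs once per image, and the decoder then appends the highest-scoring token at each step until it emits an end-of-sequence token or reaches a length limit.

```python
from typing import Callable, Dict, List, Sequence

EOS = "<eos>"  # hypothetical end-of-sequence token


def greedy_decode(
    image_features: Sequence[float],
    next_token_scores: Callable[[Sequence[float], List[str]], Dict[str, float]],
    max_length: int = 20,
) -> List[str]:
    """Toy greedy decoder: pick the best-scoring token at every step."""
    tokens: List[str] = []
    for _ in range(max_length):
        # The decoder conditions on the encoded image and the tokens generated so far.
        scores = next_token_scores(image_features, tokens)
        best = max(scores, key=scores.get)
        if best == EOS:
            break
        tokens.append(best)
    return tokens


# Hypothetical stand-in for the ViT encoder: reduces pixels to one feature.
def encode(pixels: Sequence[float]) -> List[float]:
    return [sum(pixels) / len(pixels)]


CAPTION = ["a", "small", "dog"]


# Hypothetical stand-in for the GPT-2 decoder: prefers the next word of a fixed caption, then EOS.
def next_token_scores(features: Sequence[float], prefix: List[str]) -> Dict[str, float]:
    target = CAPTION[len(prefix)] if len(prefix) < len(CAPTION) else EOS
    return {word: (1.0 if word == target else 0.0) for word in CAPTION + [EOS]}


print(" ".join(greedy_decode(encode([0.2, 0.4, 0.6]), next_token_scores)))  # a small dog
```

Real captioning models score a vocabulary of tens of thousands of tokens and may use other search strategies such as beam search; the loop structure is what this sketch is meant to show.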

### Downloading Images

In [None]:
!wget -q https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/images/images.zip

In [None]:
import shutil
shutil.unpack_archive("images.zip", "images", "zip")
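
The archive unpacks into a nested `images/images/` folder, and the file extensions mix upper and lower case (`.JPEG`, `.jpg`, `.jpeg`). An optional sanity check with the standard library lists what was extracted. The `find_images` helper and its extension set are assumptions for illustration, not part of Spark NLP.

```python
from pathlib import Path
from typing import List

# Assumed set of image extensions; compared case-insensitively.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def find_images(root: str) -> List[str]:
    """Return sorted paths of image files under `root`, matching extensions case-insensitively."""
    return sorted(
        str(path)
        for path in Path(root).rglob("*")
        if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
    )


image_paths = find_images("images/images")
print(len(image_paths), "images found")
```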

### Start Spark Session

In [None]:
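import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession

# Start a Spark session with Spark NLP loaded.
# sparknlp.start() also accepts gpu=True and a memory setting (e.g. memory="16G") for heavy models.
spark = sparknlp.start()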

In [None]:
data_df = spark.read.format("image").option("dropInvalid", value = True).load(path="images/images/")

### Pipeline with VisionEncoderDecoderForImageCaptioning

In [None]:
image_assembler = ImageAssembler() \
    .setInputCol("image") \
    .setOutputCol("image_assembler")

image_captioning = VisionEncoderDecoderForImageCaptioning \
    .pretrained() \
    .setInputCols(["image_assembler"]) \
    .setOutputCol("caption")

pipeline = Pipeline(stages=[
    image_assembler,
    image_captioning,
])

In [None]:
model = pipeline.fit(data_df)
image_df = model.transform(data_df)
image_df \
    .selectExpr("reverse(split(image.origin, '/'))[0] as image_name", "caption.result") \
    .show(truncate = False)

+-----------------+---------------------------------------------------------+
|image_name       |result                                                   |
+-----------------+---------------------------------------------------------+
|palace.JPEG      |[a large room filled with furniture and a large window]  |
|egyptian_cat.jpeg|[a cat laying on a couch next to another cat]            |
|hippopotamus.JPEG|[a brown bear in a body of water]                        |
|hen.JPEG         |[a flock of chickens standing next to each other]        |
|ostrich.JPEG     |[a large bird standing on top of a lush green field]     |
|junco.JPEG       |[a small bird standing on a wet ground]                  |
|bluetick.jpg     |[a small dog standing on a wooden floor]                 |
|chihuahua.jpg    |[a small brown dog wearing a blue sweater]               |
|tractor.JPEG     |[a man is standing in a field with a tractor]            |
|ox.JPEG          |[a large brown cow standing on top of a lush 
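
The `selectExpr` above keeps only the file name by splitting `image.origin` on `/`, reversing the array and taking its first element. Spark SQL's `element_at(split(image.origin, '/'), -1)` is an equivalent way to take the last element. The plain-Python sketch below mirrors the same logic for a single origin string, which can be handy when post-processing collected rows; the `image_name` helper is hypothetical.

```python
def image_name(origin: str) -> str:
    """Return the last '/'-separated component of an image origin path."""
    # Same result as reverse(split(origin, '/'))[0] in Spark SQL.
    return origin.split("/")[-1]


print(image_name("images/images/palace.JPEG"))  # palace.JPEG
```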

### Light Pipeline

To use the annotator in a light pipeline, we need to use the new method `fullAnnotateImage`, which can receive 2 kinds of input:
1. A path to a single image
2. A list of image paths

In [None]:
light_pipeline = LightPipeline(model)
annotations_result = light_pipeline.fullAnnotateImage("images/images/hippopotamus.JPEG")
annotations_result[0].keys()

dict_keys(['image_assembler', 'caption'])

To process several images at once, we just pass a list of image paths.

In [None]:
images = ["images/images/bluetick.jpg", "images/images/palace.JPEG", "images/images/hen.JPEG"]
annotations_result = light_pipeline.fullAnnotateImage(images)
annotations_result[0].keys()

dict_keys(['image_assembler', 'caption'])
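
Because the model is heavy, passing a very long list in a single call may use a lot of memory. One option, sketched here under that assumption, is to split the paths into fixed-size batches and concatenate the results in order. `annotate_in_batches` and `batch_size` are hypothetical names; in this notebook `annotate` would be `light_pipeline.fullAnnotateImage`, while the example call uses a stand-in function so it runs without Spark.

```python
from typing import Callable, Dict, List, Sequence


def annotate_in_batches(
    annotate: Callable[[List[str]], List[Dict]],
    paths: Sequence[str],
    batch_size: int = 8,
) -> List[Dict]:
    """Call `annotate` on consecutive slices of `paths` and concatenate the results in order."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    results: List[Dict] = []
    for start in range(0, len(paths), batch_size):
        results.extend(annotate(list(paths[start:start + batch_size])))
    return results


# Stand-in for light_pipeline.fullAnnotateImage: one result dict per input path.
def fake_annotate(batch: List[str]) -> List[Dict]:
    return [{"caption": f"caption for {p}"} for p in batch]


out = annotate_in_batches(fake_annotate, ["a.jpg", "b.jpg", "c.jpg"], batch_size=2)
print(len(out))  # 3
```

With the real pipeline the call would look like `annotate_in_batches(light_pipeline.fullAnnotateImage, images, batch_size=2)`; a smaller batch trades throughput for lower peak memory.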

In [None]:
for result in annotations_result:
    print(result['caption'])

[Annotation(document, 0, 37, a small dog standing on a wooden floor, Map(nChannels -> 3, image -> 0, height -> 500, origin -> images/images/bluetick.jpg, mode -> 16, width -> 333), [])]
[Annotation(document, 0, 52, a large room filled with furniture and a large window, Map(nChannels -> 3, image -> 0, height -> 334, origin -> images/images/palace.JPEG, mode -> 16, width -> 500), [])]
[Annotation(document, 0, 46, a flock of chickens standing next to each other, Map(nChannels -> 3, image -> 0, height -> 375, origin -> images/images/hen.JPEG, mode -> 16, width -> 500), [])]
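
Each printed `Annotation` carries the caption in `result`, inclusive character offsets in `begin`/`end` (0 to 37 for the 38-character bluetick caption), and the source path under the `origin` key of `metadata`. A small helper can turn the LightPipeline output into a `{file name: caption}` dictionary. This is a sketch that assumes each result dict has a `caption` list whose annotations expose `result` and `metadata` attributes, as Spark NLP's Python `Annotation` objects do; `captions_by_image` is a hypothetical name, and the namedtuple stand-in only lets the example run without Spark.

```python
from collections import namedtuple
from typing import Dict, List


def captions_by_image(results: List[Dict]) -> Dict[str, str]:
    """Map each image's file name to its generated caption."""
    captions: Dict[str, str] = {}
    for result in results:
        for annotation in result["caption"]:
            origin = annotation.metadata["origin"]
            captions[origin.split("/")[-1]] = annotation.result
    return captions


# Stand-in with the same attribute names as sparknlp's Annotation, for running without Spark.
FakeAnnotation = namedtuple("FakeAnnotation", ["begin", "end", "result", "metadata"])
caption = "a small dog standing on a wooden floor"
fake_results = [{
    "caption": [FakeAnnotation(0, len(caption) - 1, caption,
                               {"origin": "images/images/bluetick.jpg"})]
}]
print(captions_by_image(fake_results))
# {'bluetick.jpg': 'a small dog standing on a wooden floor'}
```

In the notebook itself the call would simply be `captions_by_image(annotations_result)`.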
