# Example of Using Spark OCR for Image Brands Extraction

## Import Spark OCR transformers

## Install the spark-ocr Python package
You need to specify either the path to `spark-ocr-assembly-[version].jar` or the `secret`.

In [None]:
# Credentials: the John Snow Labs `secret` is used both to install and to
# start Spark OCR; `license` is optional and only exported when non-empty.
# NOTE: `license` shadows the interactive `license()` builtin; the name is
# kept because later cells read it.
secret = ""
license = ""

# The Spark OCR package version is the prefix of the secret up to the first "-".
version = secret.partition("-")[0]

# Directory handed to start() as `jar_path` (a locally built spark-ocr assembly).
spark_ocr_jar_path = "../../target/scala-2.11"

In [None]:
%%bash
# Install a headless OpenJDK 8 only when running inside Google Colab;
# on any other machine this cell does nothing.
if ! python -c 'import google.colab' &> /dev/null; then
    exit 0
fi

echo "Run on Google Colab!"
echo "Install Open JDK"
# Quiet, non-interactive package install; stdout is discarded.
apt-get install -y openjdk-8-jdk-headless -qq > /dev/null
java -version

In [None]:
import os
import sys

# On Colab, point JAVA_HOME at the OpenJDK 8 directory and put its `bin`
# first on PATH so Spark picks up that JVM.
if 'google.colab' in sys.modules:
    java_home = "/usr/lib/jvm/java-8-openjdk-amd64"
    os.environ["JAVA_HOME"] = java_home
    os.environ["PATH"] = java_home + "/bin:" + os.environ["PATH"]

In [None]:
# Install Spark NLP from PyPI, and Spark OCR from the John Snow Labs package
# index, whose URL is built from the `secret`.
# IPython expands `$version` and `$secret` from the Python variables of the same name.
# NOTE(review): the spark-nlp pin (2.4.5) must match `nlp_version` passed to start().
%pip install spark-nlp==2.4.5
%pip install spark-ocr==$version --user --extra-index-url=https://pypi.johnsnowlabs.com/$secret --upgrade

In [None]:
# Alternatively, install Spark OCR from a locally built source package:
# %pip install --user ../dist/spark-ocr-[version].tar.gz

## Initialize the Spark session
You need to specify either the path to `spark-ocr-assembly.jar` or the `secret`.

In [None]:
from sparkocr import start

# Make the optional OCR license visible to Spark OCR through the environment.
if license:
    os.environ['JSL_OCR_LICENSE'] = license

# Create the Spark session with Spark OCR (and Spark NLP 2.4.5) attached;
# the bare `spark` expression below displays the session in the notebook.
spark = start(
    secret=secret,
    jar_path=spark_ocr_jar_path,
    nlp_version="2.4.5",
)
spark

In [None]:
from pyspark.sql.functions import  col
from pyspark.ml import Pipeline
from sparknlp.base import *
from sparkocr.transformers import *
from sparkocr.enums import *
from termcolor import colored

## Define OCR transformers and pipeline

In [None]:
# Build the preprocessing + OCR pipeline:
# raw bytes -> image -> binarized -> morphology -> small-object removal -> region OCR.
# Each stage reads the column written by the previous stage.

# Decode the raw file bytes in the `content` column (as produced by the
# "binaryFile" reader used further down) into an `image` column.
binary_to_image = BinaryToImage()
binary_to_image.setInputCol("content")
binary_to_image.setOutputCol("image")

# Binarize using adaptive thresholding
binarizer = ImageAdaptiveThresholding()
binarizer.setInputCol("image")
binarizer.setOutputCol("binarized_image")

# Apply a morphology operation with a square kernel of size 2.
# NOTE(review): no operation type is set here, so the transformer's default is
# used; the output column name suggests an "opening" - confirm in the Spark OCR docs.
operation = ImageMorphologyOperation()
operation.setKernelShape(KernelShape.SQUARE)
operation.setKernelSize(2)
operation.setInputCol("binarized_image")
operation.setOutputCol("opening_image")

# Remove small objects (presumably noise specks smaller than the min font
# size of 48; confirm the unit and semantics of minSizeFont in the Spark OCR docs).
remove_objects = ImageRemoveObjects()
remove_objects.setInputCol("opening_image")
remove_objects.setOutputCol("corrected_image")
remove_objects.setMinSizeFont(48)

# Run Tesseract OCR on the corrected image, restricted to the named regions
# configured below; results are read later as `image_brands.<name>.text`.
ocr_corrected = ImageToText()
ocr_corrected.setInputCol("corrected_image")
ocr_corrected.setOutputCol("image_brands")
# NOTE(review): presumably controls whether the image resolution is taken into
# account during OCR - confirm in the Spark OCR docs.
ocr_corrected.setIgnoreResolution(False)
# Raw Tesseract config variables, passed through as "name=value" strings.
ocr_corrected.setOcrParams(["preserve_interword_spaces=1", ])
# NOTE(review): SINGLE_WORD makes Tesseract treat each region as one word, yet
# regions such as "name" or "issue_date" may contain several words (and
# preserve_interword_spaces only matters for multi-word text); SINGLE_LINE may
# fit better - verify on the sample images.
ocr_corrected.setPageSegMode(PageSegmentationMode.SINGLE_WORD)
# Named regions to OCR, given as a JSON list of
# {"name": ..., "rectangle": {"x", "y", "width", "height"}}.
# The coordinates are presumably pixels in the input image, tuned for the
# dollar_bonds sample images - confirm before reusing with other documents.
ocr_corrected.setBrandsCoords("""
              [
                 {
                    "name": "name",
                    "rectangle": {
                       "x": 250,
                       "y": 158,
                       "width": 204,
                       "height": 23
                    }
                 },
                 {
                    "name": "issue_date",
                    "rectangle": {
                       "x": 641,
                       "y": 156,
                       "width": 129,
                       "height": 20
                    }
                 },
                 {
                    "name": "serial_number",
                    "rectangle": {
                       "x": 570,
                       "y": 343,
                       "width": 188,
                       "height": 33
                    }
                 }
              ]

""")

# OCR pipeline: the stages are applied in the order listed.
pipeline = Pipeline(stages=[
    binary_to_image,   # content -> image
    binarizer,         # image -> binarized_image
    operation,         # binarized_image -> opening_image
    remove_objects,    # opening_image -> corrected_image
    ocr_corrected      # corrected_image -> image_brands
])


## Read images as binary files

In [None]:
# Load every .jpg in the dollar_bonds sample folder through the "binaryFile"
# data source; its `content` column feeds BinaryToImage and its `path`
# column is reported in the results.
image_path = '././data/images/dollar_bonds/*.jpg'
image_df = spark.read.load(image_path, format="binaryFile")

## Run the OCR pipeline

In [None]:
# Fit the pipeline on the images, apply the fitted model to the same
# DataFrame, and cache the OCR output because it is queried again below.
ocr_model = pipeline.fit(image_df)
ocr_result = ocr_model.transform(image_df).cache()

## Results

In [None]:
# (region name used in setBrandsCoords, label printed for it)
brand_fields = (
    ("name", "Name"),
    ("issue_date", "Issue Date"),
    ("serial_number", "Serial Number"),
)

# Select the image path plus the recognised text of every region, aliased to
# the region name, and bring the (small) result set back to the driver.
results = ocr_result.select(
    col("path"),
    *[col("image_brands.%s.text" % key).alias(key) for key, _ in brand_fields]
).collect()

for row in results:
    print(colored("path:\n%s" % row.path, "red"))
    for key, label in brand_fields:
        print("%s:\n%s" % (label, row[key]))

