![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/5.1.Spark_OCR_Multi_Modals.ipynb)

# Spark OCR 


## Blogposts and videos

- [How to Setup Spark OCR on UBUNTU - Video](https://www.youtube.com/watch?v=cmt4WIcL0nI)

- [Installing Spark NLP and Spark OCR in air-gapped networks (offline mode)](https://medium.com/spark-nlp/installing-spark-nlp-and-spark-ocr-in-air-gapped-networks-offline-mode-f42a1ee6b7a8)

- [Table Detection & Extraction in Spark OCR](https://medium.com/spark-nlp/table-detection-extraction-in-spark-ocr-50765c6cedc9)

- [Signature Detection in Spark OCR](https://medium.com/spark-nlp/signature-detection-in-spark-ocr-32f9e6f91e3c)

- [GPU image pre-processing in Spark OCR](https://medium.com/spark-nlp/gpu-image-pre-processing-in-spark-ocr-3-1-0-6fc27560a9bb)

**More examples here**

https://github.com/JohnSnowLabs/spark-ocr-workshop

**Colab Setup**

In [None]:
import json, os
from google.colab import files

# Upload the license only if the file read below is not already present
if 'spark_ocr.json' not in os.listdir():
  license_keys = files.upload()
  os.rename(list(license_keys.keys())[0], 'spark_ocr.json')

with open('spark_ocr.json') as f:
    license_keys = json.load(f)

# Defining license key-value pairs as local variables
locals().update(license_keys)
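
The later cells refer to `SPARK_OCR_SECRET`, `OCR_VERSION` and `PUBLIC_VERSION` by name, so a missing key only surfaces as a `NameError` further down. A minimal sketch of an early check, assuming those three keys are the ones this notebook needs (the helper `load_license` and the placeholder values are hypothetical, not part of Spark OCR):

```python
import json
import os
import tempfile

# Keys that later cells in this notebook reference by name (assumption:
# a real license file may contain additional keys, which are left untouched).
REQUIRED_KEYS = ("SPARK_OCR_SECRET", "OCR_VERSION", "PUBLIC_VERSION")


def load_license(path, required=REQUIRED_KEYS):
    """Load a license JSON file and fail early if an expected key is missing."""
    with open(path) as f:
        keys = json.load(f)
    missing = [k for k in required if k not in keys]
    if missing:
        raise KeyError(f"license file is missing keys: {missing}")
    return keys


# Demo with a throwaway file holding placeholder values (not real credentials).
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "spark_ocr.json")
    with open(demo_path, "w") as f:
        json.dump(
            {"SPARK_OCR_SECRET": "xxx", "OCR_VERSION": "0.0.0", "PUBLIC_VERSION": "0.0.0"},
            f,
        )
    demo_keys = load_license(demo_path)
```

In the notebook itself, `load_license('spark_ocr.json')` could stand in for the plain `json.load` call before the keys are copied into local variables.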

In [None]:
# Installing Spark OCR
! pip install spark-ocr==$OCR_VERSION --extra-index-url=https://pypi.johnsnowlabs.com/$SPARK_OCR_SECRET --upgrade

# Installing pyspark and spark-nlp
! pip install --upgrade -q pyspark==3.2.1 spark-nlp==$PUBLIC_VERSION

<b><h1><font color='darkred'>!!! ATTENTION !!! </font></h1></b>

<b>After running the previous cell, <font color='darkred'>RESTART the COLAB RUNTIME </font> and then continue with the cells below.</b>

In [None]:
import json, os

with open("spark_ocr.json", 'r') as f:
  license_keys = json.load(f)

# Adding license key-value pairs to environment variables
os.environ.update(license_keys)

# Defining license key-value pairs as local variables
locals().update(license_keys)

In [None]:
import sparkocr
import sys
from pyspark.sql import SparkSession
from sparkocr import start
import base64
from sparkocr.transformers import *
from pyspark.ml import PipelineModel
from pyspark.sql import functions as F
from sparkocr.enums import *
from sparkocr.utils import display_images

In [None]:
# Start spark
spark = sparkocr.start(secret=SPARK_OCR_SECRET, 
                       nlp_version=PUBLIC_VERSION
                       )

spark

## Working with tables

### Table Detection

**Load images**

In [None]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-ocr-workshop/master/jupyter/data/tab_images/cTDaR_t10168.jpg -P table_image/
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-ocr-workshop/master/jupyter/data/tab_images/cTDaR_t10011.jpg -P table_image/

**Read and display images**

In [None]:
from sparkocr.utils import display_tables

In [None]:
image_df = spark.read.format("binaryFile").load("/content/table_image/*.jpg")

display_images(BinaryToImage().transform(image_df), "image")

**Create OCR Pipeline**

In [None]:
binary_to_image = BinaryToImage() 
binary_to_image.setImageType(ImageType.TYPE_3BYTE_BGR)

table_detector = ImageTableDetector.pretrained("general_model_table_detection_v2", "en", "clinical/ocr")
table_detector.setInputCol("image")
table_detector.setOutputCol("table_regions")

draw_regions = ImageDrawRegions()
draw_regions.setInputCol("image")
draw_regions.setInputRegionsCol("table_regions")
draw_regions.setOutputCol("image_with_regions")
draw_regions.setRectColor(Color.red)

pipeline = PipelineModel(stages=[
    binary_to_image,
    table_detector,
    draw_regions
])

**Show results**

In [None]:
result =  pipeline.transform(image_df)
display_images(result, "image_with_regions")

### Table Recognition

In [None]:
from pyspark.sql.functions import desc, row_number, monotonically_increasing_id
from pyspark.sql.window import Window
import pyspark.sql.functions as f

**Load and display the images.**

In [None]:
image_df= spark.read.format("binaryFile").load("table_image")

# add index to the dataframe
image_df_with_seq_id = image_df.withColumn('index', row_number().over(Window.orderBy(monotonically_increasing_id())) - 1)

display_images(BinaryToImage().transform(image_df), "image")

**Create OCR pipeline.**

In [None]:
binary_to_image = BinaryToImage()
# The image type must be set explicitly for the table detection model to work correctly
binary_to_image.setImageType(ImageType.TYPE_3BYTE_BGR)

table_detector = ImageTableDetector.pretrained("general_model_table_detection_v2", "en", "clinical/ocr")
table_detector.setInputCol("image")
table_detector.setOutputCol("region")

splitter = ImageSplitRegions()
splitter.setInputCol("image")
splitter.setInputRegionsCol("region")
splitter.setOutputCol("table_image")
splitter.setDropCols("image")
splitter.setImageType(ImageType.TYPE_BYTE_GRAY)

scaler = ImageScaler()
scaler.setInputCol("table_image")
scaler.setOutputCol("scaled_image")
scaler.setScaleFactor(2)

cell_detector = ImageTableCellDetector()
cell_detector.setInputCol("scaled_image")
cell_detector.setOutputCol("cells")
cell_detector.setKeepInput(True)

table_recognition = ImageCellsToTextTable()
table_recognition.setInputCol("scaled_image")
table_recognition.setCellsCol('cells')
table_recognition.setMargin(1)
table_recognition.setStrip(True)
table_recognition.setOutputCol('table')


pipeline = PipelineModel(stages=[
    binary_to_image,
    table_detector,
    splitter,
    scaler,
    cell_detector,
    table_recognition
])

**Run the pipeline and display the tables in the images.**

In [None]:
results = pipeline.transform(image_df_with_seq_id).cache()
display_images(results, "table_image")

**Display recognized tables**

In [None]:
display_tables(results)

### Table Cell Recognition

**Read and display images**

In [None]:
from sparkocr.transformers import *
from sparkocr.enums import *
from sparkocr.utils import display_images, display_table, display_tables
from pyspark.ml import PipelineModel
import pyspark.sql.functions as f

import pkg_resources
test_image_path = pkg_resources.resource_filename('sparkocr', 'resources/ocr/tableImage/table[0,2,4]*.*')
image_df= spark.read.format("binaryFile").load(test_image_path)

display_images(BinaryToImage().transform(image_df), "image")

**Define OCR Pipeline**

In [None]:
binary_to_image = BinaryToImage()
binary_to_image.setImageType(ImageType.TYPE_BYTE_GRAY)
binary_to_image.setInputCol("content")

cell_detector = ImageTableCellDetector()
cell_detector.setInputCol("image")
cell_detector.setOutputCol("cells")
cell_detector.setKeepInput(True)
cell_detector.setAlgoType("morphops")
cell_detector.setDrawDetectedLines(True)

table_recognition = ImageCellsToTextTable()
#table_recognition.setInputCol("output_image")
table_recognition.setInputCol("image")
table_recognition.setCellsCol('cells')
table_recognition.setMargin(2)
table_recognition.setStrip(True)
table_recognition.setOutputCol('table')

pipeline = PipelineModel(stages=[
    binary_to_image,
    cell_detector,
    table_recognition
])

**Run Pipeline**

In [None]:
results = pipeline.transform(image_df).cache()

In [None]:
results= results.na.drop(subset=["cells", "output_image", "table"])

In [None]:
display_images(results, "image")

In [None]:
display_tables(results)

**Show the recognized table in CSV format for one image**

In [None]:
table_recognition_csv = ImageCellsToTextTable()
table_recognition_csv.setInputCol("image")
table_recognition_csv.setCellsCol('cells')
table_recognition_csv.setMargin(2)
table_recognition_csv.setStrip(True)
table_recognition_csv.setOutputCol('table')
table_recognition_csv.setOutputFormat('csv')

pipeline_csv = PipelineModel(stages=[
    binary_to_image,
    cell_detector,
    table_recognition_csv
])

In [None]:
results_csv = pipeline_csv.transform(image_df).cache()
print(results_csv.select("table").collect()[2].table)
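
Once the table is available as CSV text, it can be parsed with Python's standard `csv` module for further processing. The sketch below assumes the `table` column holds a plain CSV string when the output format is `'csv'`; the sample text and the helper `csv_to_rows` are made up for illustration and are not real pipeline output.

```python
import csv
import io

# Illustrative stand-in for results_csv.select("table").collect()[i].table;
# the values are invented, not produced by the pipeline above.
sample_table_csv = 'Name,Dose,Unit\n"Aspirin, low dose",81,mg\nIbuprofen,200,mg\n'


def csv_to_rows(text):
    """Parse a CSV string into a header list and a list of dict rows."""
    reader = csv.reader(io.StringIO(text))
    header = next(reader)
    return header, [dict(zip(header, row)) for row in reader]


header, rows = csv_to_rows(sample_table_csv)
for row in rows:
    print(row)
```

Using `csv.reader` rather than splitting on commas keeps quoted cells that contain commas intact.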

## Handwritten Detection

In [None]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-ocr-workshop/master/jupyter/data/handwritten/page1.jpeg

In [None]:
image_df = spark.read.format("binaryFile").load('page1.jpeg')

display_images(BinaryToImage().transform(image_df), "image")

In [None]:
binary_to_image = BinaryToImage()
binary_to_image.setImageType(ImageType.TYPE_3BYTE_BGR)

pretrained_model = ("image_handwritten_detector_gsa0803", "en", "public/ocr/models")

handwritten_detector = ImageHandwrittenDetector() \
    .pretrained(*pretrained_model) \
    .setInputCol("image") \
    .setOutputCol("handwritten_regions") \
    .setScoreThreshold(0.4)

draw_regions = ImageDrawRegions() \
    .setInputCol("image") \
    .setInputRegionsCol("handwritten_regions") \
    .setOutputCol("image_with_regions") \
    .setFontSize(16) \
    .setRectColor(Color.red)

pipeline = PipelineModel(stages=[
    binary_to_image,
    handwritten_detector,
    draw_regions
])

In [None]:
result =  pipeline.transform(image_df).cache()
display_images(result, "image_with_regions")
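
The detector above is configured with `setScoreThreshold(0.4)`. As a rough mental model, and only under the assumption that the threshold keeps detections whose score is at least the given value, the filtering can be pictured as follows. The detection tuples and the helper `keep_confident` are hypothetical; they do not reflect the actual schema of `handwritten_regions`.

```python
SCORE_THRESHOLD = 0.4  # same value as passed to setScoreThreshold above

# Hypothetical detections: (label, score, (x, y, width, height)).
detections = [
    ("handwritten", 0.91, (40, 120, 300, 60)),
    ("handwritten", 0.35, (10, 10, 50, 20)),
    ("handwritten", 0.40, (200, 400, 120, 40)),
]


def keep_confident(dets, threshold=SCORE_THRESHOLD):
    """Keep detections whose score reaches the threshold (assumed inclusive)."""
    return [d for d in dets if d[1] >= threshold]


kept = keep_confident(detections)
print(len(kept), "of", len(detections), "detections kept")
```

Raising the threshold trades recall for precision: fewer regions are drawn, but those that remain are the ones the model is more confident about.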

## Signature Detection

In [None]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-ocr-workshop/master/jupyter/data/signature/image_147.png

In [None]:
image_file = 'image_147.png'
image_df = spark.read.format("binaryFile").load(image_file)

display_images(BinaryToImage().transform(image_df), "image")

**Define OCR Pipeline**

In [None]:
binary_to_image = BinaryToImage()
binary_to_image.setImageType(ImageType.TYPE_3BYTE_BGR)

pretrained_model = ("image_handwritten_detector_gsa0628", "en", "public/ocr/models")
signature_detector = ImageHandwrittenDetector() \
    .pretrained(*pretrained_model) \
    .setInputCol("image") \
    .setOutputCol("signature_regions") \
    .setOutputLabels(["signature"]) \
    .setScoreThreshold(0.4)

draw_regions = ImageDrawRegions() \
    .setInputCol("image") \
    .setInputRegionsCol("signature_regions") \
    .setOutputCol("image_with_regions") \
    .setFontSize(16) \
    .setRectColor(Color.red)

pipeline = PipelineModel(stages=[
    binary_to_image,
    signature_detector,
    draw_regions
])

**Run pipeline and show results**

In [None]:
result =  pipeline.transform(image_df).cache()
display_images(result, "image_with_regions")

## Visual Document NER

In [None]:
from sparkocr.transformers import ImageToHocr, VisualDocumentNer, BinaryToImage
from sparkocr.utils import display_images, display_image

In [None]:
import pkg_resources
test_image_path = pkg_resources.resource_filename('sparkocr', 'resources/ocr/images/SROIE/')
bin_df = spark.read.format("binaryFile").load(test_image_path)
bin_df.show()

In [None]:
image_df = BinaryToImage().transform(bin_df)
display_images(image_df)

In [None]:
binary_to_image = BinaryToImage()\
    .setOutputCol("image")

img_to_hocr = ImageToHocr()\
    .setInputCol("image")\
    .setOutputCol("hocr")\
    .setIgnoreResolution(False)\
    .setOcrParams(["preserve_interword_spaces=0"])

doc_ner = VisualDocumentNer()\
    .pretrained("visual_document_NER_SROIE0526", "en", "public/ocr/models")\
    .setInputCol("hocr")    

# OCR pipeline
ner_pipeline = PipelineModel(stages=[
    binary_to_image,
    img_to_hocr,
    doc_ner
])

In [None]:
results = ner_pipeline.transform(bin_df).cache()

In [None]:
pd_df = results.select('entities').toPandas().explode('entities')

pd_df['label'] = pd_df.entities.apply(lambda a : a[3])
pd_df['chunk'] = pd_df.entities.apply(lambda a : a[4]['word'])
pd_df[pd_df['label'] != "O"][['label', 'chunk']].drop_duplicates()
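
The cell above reads the label from position 3 of each entity and the word from the metadata dictionary at position 4, then drops the `"O"` (outside) label and duplicates. A pandas-free sketch of the same logic, using hypothetical annotation tuples and made-up label names that only mimic that indexing:

```python
# Hypothetical annotations shaped like the rows indexed above:
# position 3 holds the predicted label, position 4 a metadata dict with the word.
entities = [
    ("named_entity", 0, 6, "O", {"word": "Receipt"}),
    ("named_entity", 7, 11, "B-TOTAL", {"word": "12.50"}),
    ("named_entity", 12, 21, "B-DATE", {"word": "2019-05-26"}),
    ("named_entity", 22, 26, "B-TOTAL", {"word": "12.50"}),
]


def labelled_chunks(annotations, outside_label="O"):
    """Return unique (label, word) pairs, skipping the outside label."""
    seen = set()
    out = []
    for ann in annotations:
        pair = (ann[3], ann[4]["word"])
        if pair[0] != outside_label and pair not in seen:
            seen.add(pair)
            out.append(pair)
    return out


chunks = labelled_chunks(entities)
print(chunks)
```

The first-seen order is preserved, which matches what `drop_duplicates()` does by default in the pandas version.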

## Visual Document NER v2

In [None]:
import pkg_resources
test_image_path = pkg_resources.resource_filename('sparkocr', 'resources/ocr/forms/form1.jpg')
bin_df = spark.read.format("binaryFile").load(test_image_path)
bin_df.show()

In [None]:
image_df = BinaryToImage().transform(bin_df)
display_images(image_df)

In [None]:
binary_to_image = BinaryToImage()\
    .setOutputCol("image") \
    .setImageType(ImageType.TYPE_3BYTE_BGR)

img_to_hocr = ImageToHocr()\
    .setInputCol("image")\
    .setOutputCol("hocr")\
    .setIgnoreResolution(False)\
    .setOcrParams(["preserve_interword_spaces=0"])

tokenizer = HocrTokenizer()\
    .setInputCol("hocr")\
    .setOutputCol("token")

doc_ner = VisualDocumentNerV2()\
    .pretrained("layoutlmv2_funsd", "en", "clinical/ocr")\
    .setInputCols(["token", "image"])\
    .setOutputCol("entities")

draw = ImageDrawAnnotations() \
    .setInputCol("image") \
    .setInputChunksCol("entities") \
    .setOutputCol("image_with_annotations") \
    .setFontSize(10) \
    .setLineWidth(4)\
    .setRectColor(Color.red)

# OCR pipeline
pipeline = PipelineModel(stages=[
    binary_to_image,
    img_to_hocr,
    tokenizer,
    doc_ner,
    draw
])

results = pipeline.transform(bin_df).cache()

In [None]:
pd_df = results.select('entities').toPandas().explode('entities')

pd_df['label'] = pd_df.entities.apply(lambda a : a[3])
pd_df['chunk'] = pd_df.entities.apply(lambda a : a[4]['word'])
pd_result = pd_df[pd_df['label'] != "O"][['label', 'chunk']].drop_duplicates()

In [None]:
pd_result

In [None]:
display_images(results, "image_with_annotations", width=1000)

## Visual Document Classifier

In [None]:
import pkg_resources
test_image_path = pkg_resources.resource_filename('sparkocr', 'resources/ocr/images/document_classification/')
bin_df = spark.read.format("binaryFile").load(test_image_path)
bin_df.show()

In [None]:
for item in BinaryToImage().transform(bin_df).select("image").collect():
    display_image(item.image)

In [None]:
binary_to_image = BinaryToImage()\
    .setOutputCol("image")

img_to_hocr = ImageToHocr()\
    .setInputCol("image")\
    .setOutputCol("hocr")\
    .setIgnoreResolution(False)\
    .setOcrParams(["preserve_interword_spaces=0"])

doc_classifier = VisualDocumentClassifier()\
    .pretrained("visual_document_classifier_tobacco3482", "en", "clinical/ocr")\
    .setInputCol("hocr")\
    .setLabelCol("label")\
    .setConfidenceCol("conf")

# OCR pipeline
classifier_pipeline = PipelineModel(stages=[
    binary_to_image,
    img_to_hocr,
    doc_classifier
])

In [None]:
results = classifier_pipeline.transform(bin_df)

In [None]:
import pyspark.sql.functions as f

path_array = f.split(results['path'], '/')
# show() returns None, so display the table without reassigning `results`
results.withColumn('filename', path_array.getItem(f.size(path_array) - 1)) \
       .select("filename", "label", "conf") \
       .show(truncate=False)
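
The `filename` column is derived by splitting the path on `/` and taking the last element. The same idea in plain Python, as a minimal sketch with made-up example paths (the helper `filename_from_path` is hypothetical):

```python
def filename_from_path(path):
    """Split on '/' and return the last element, mirroring getItem(size - 1)."""
    parts = path.split("/")
    return parts[len(parts) - 1]


# Made-up paths in the style of the binaryFile `path` column.
paths = [
    "file:/content/docs/letter_01.jpg",
    "file:/content/docs/memo_07.png",
]
filenames = [filename_from_path(p) for p in paths]
print(filenames)
```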

**Classes in visual_document_classifier_tobacco3482**
- Advertisement
- Email
- Form
- Letter
- Memo
- Report
- Resume
- Scientific



## LayoutLMv2 for Key Value Pair Extraction

In [None]:
import pkg_resources
test_image_path = pkg_resources.resource_filename('sparkocr', 'resources/ocr/forms/form1.jpg')
bin_df = spark.read.format("binaryFile").load(test_image_path)
bin_df.show()

In [None]:
image_df = BinaryToImage().transform(bin_df)
display_images(image_df)

In [None]:
binary_to_image = BinaryToImage()\
    .setOutputCol("image") \
    .setImageType(ImageType.TYPE_3BYTE_BGR)

img_to_hocr = ImageToHocr()\
    .setInputCol("image")\
    .setOutputCol("hocr")\
    .setIgnoreResolution(False)\
    .setOcrParams(["preserve_interword_spaces=0"])

tokenizer = HocrTokenizer()\
    .setInputCol("hocr")\
    .setOutputCol("token")

doc_ner_cust = VisualDocumentNerV2()\
    .pretrained("layoutlmv2_key_value_pairs", "en", "clinical/ocr")\
    .setInputCols(["token", "image"])\
    .setOutputCol("entities")\
    .setLabels(["other",
                "header",
                "header",
                "key",
                "key",
                "value",
                "value"])\
    .setWhiteList(["header",
                   "key",
                   "value"])

draw = ImageDrawAnnotations() \
    .setInputCol("image") \
    .setInputChunksCol("entities") \
    .setOutputCol("image_with_annotations") \
    .setFontSize(10) \
    .setLineWidth(4)\
    .setRectColor(Color.red)


# OCR pipeline
pipeline = PipelineModel(stages=[
    binary_to_image,
    img_to_hocr,
    tokenizer,
    doc_ner_cust,
    draw
])

results = pipeline.transform(bin_df).cache()
display_images(results, "image_with_annotations", width=1000)
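
The `setLabels` list above repeats `header`, `key` and `value`, which suggests that several model output classes are merged into one displayed label, while `setWhiteList` restricts the output to the three labels of interest. The sketch below illustrates that reading; it assumes the list is indexed by the model's class id and that the whitelist simply filters by label name. The class ids, words, and the helper `merge_and_filter` are hypothetical.

```python
# One entry per model output class, in class-index order (copied from setLabels above).
LABELS = ["other", "header", "header", "key", "key", "value", "value"]
# Labels to keep in the output (copied from setWhiteList above).
WHITELIST = {"header", "key", "value"}


def merge_and_filter(class_ids, words):
    """Map class ids to merged labels and keep only whitelisted ones."""
    out = []
    for cid, word in zip(class_ids, words):
        label = LABELS[cid]
        if label in WHITELIST:
            out.append((label, word))
    return out


# Hypothetical per-token predictions.
pairs = merge_and_filter([0, 3, 4, 5, 1], ["the", "Name", ":", "Jane", "FORM"])
print(pairs)
```

Under this reading, class ids 3 and 4 both surface as `key`, and tokens predicted as `other` never reach the `entities` column.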

In [None]:
pd_df = results.select('entities').toPandas().explode('entities')

pd_df['label'] = pd_df.entities.apply(lambda a : a[3])
pd_df['chunk'] = pd_df.entities.apply(lambda a : a[4]['word'])
pd_result = pd_df[pd_df['label'] != "O"][['label', 'chunk']].drop_duplicates()
pd_result