# Visual Document Classifier v2 training

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-ocr-workshop/blob/master/jupyter/VisualDocumentClassifierTraining/SparkOCRVisualDocumentClassifierv2Training.ipynb)

## Set license and AWS keys

You need to specify:
- `secret`: the Spark OCR secret, used to install the package
- the Spark OCR license key
- AWS credentials

### Option #1 - define in this cell

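```python
import os

# Fill in the values from your Spark OCR license
secret = ""
# The package version is the part of the secret before the first "-"
version = secret.split("-")[0]

os.environ["JSL_OCR_LICENSE"] = ""
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
```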

### Option #2 - provide spark_ocr.json file

```python
import json
import os
import sys

# On Colab, upload spark_ocr.json if it is not already in the working directory
if 'google.colab' in sys.modules:
    from google.colab import files

    if 'spark_ocr.json' not in os.listdir():
        license_keys = files.upload()
        os.rename(list(license_keys.keys())[0], 'spark_ocr.json')

with open('spark_ocr.json') as f:
    license_keys = json.load(f)

# Define each key-value pair from the file as a global (notebook-level) variable
globals().update(license_keys)
```
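If you prefer not to write license values into the notebook namespace, a minimal sketch like the following reads `spark_ocr.json` and exports every entry as an environment variable instead (the same mechanism Option #1 uses for `JSL_OCR_LICENSE` and the AWS keys). It assumes the file is a flat JSON object; the `SECRET` key in the demo is a hypothetical name, because the actual key names inside `spark_ocr.json` are not shown in this notebook.

```python
import json
import os
import tempfile
from typing import Dict


def load_license_keys(path: str = "spark_ocr.json", export_env: bool = True) -> Dict[str, str]:
    """Read license key/value pairs from a JSON file, optionally exporting them as env vars."""
    with open(path) as f:
        keys = json.load(f)
    if export_env:
        for name, value in keys.items():
            # Environment variables must be strings
            os.environ[name] = str(value)
    return keys


def version_from_secret(secret: str) -> str:
    """Same rule as Option #1: the version is the part of the secret before the first '-'."""
    return secret.split("-")[0]


if __name__ == "__main__":
    # Hypothetical file contents; "SECRET" is an assumed key name
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
        json.dump({"JSL_OCR_LICENSE": "dummy-license", "SECRET": "4.0.2-abcdef"}, tmp)
    try:
        keys = load_license_keys(tmp.name)
        print(os.environ["JSL_OCR_LICENSE"])        # dummy-license
        print(version_from_secret(keys["SECRET"]))  # 4.0.2
    finally:
        os.remove(tmp.name)
```

Exporting to the environment keeps the values out of `globals()`, but any code that expects plain Python variables such as `secret` would then need to read them from `os.environ`.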

## Install Spark-OCR

This step is needed only on Colab. In other environments, install Spark OCR and its dependencies yourself.

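```python
# John Snow Labs fork of transformers (branch layoutlmv2_onnx)
%pip install --upgrade git+https://github.com/JohnSnowLabs/transformers.git@layoutlmv2_onnx
# Spark OCR for Spark 3.2, installed from the John Snow Labs index using your secret
%pip install spark-ocr==$version+spark32 --extra-index-url=https://pypi.johnsnowlabs.com/$secret --upgrade
```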
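The install cell interpolates `$version` and `$secret` into a single pip command. Outside a notebook, the same command can be built explicitly; this is a sketch, and the `spark_suffix` parameter is a hypothetical convenience whose default simply reproduces the `+spark32` label used above.

```python
import shlex
import subprocess
import sys
from typing import List


def spark_ocr_pip_args(secret: str, spark_suffix: str = "spark32") -> List[str]:
    """Build the pip command from the install cell, given the Spark OCR secret."""
    version = secret.split("-")[0]  # same rule as in Option #1
    return [
        sys.executable, "-m", "pip", "install", "--upgrade",
        # Local version label selecting the package build (spark32 as in the cell above)
        f"spark-ocr=={version}+{spark_suffix}",
        # The secret is part of the private package index URL
        f"--extra-index-url=https://pypi.johnsnowlabs.com/{secret}",
    ]


def install_spark_ocr(secret: str) -> None:
    """Actually run the install (needs network access and a valid secret)."""
    subprocess.check_call(spark_ocr_pip_args(secret))


if __name__ == "__main__":
    # Print the command for a dummy secret instead of running it
    print(shlex.join(spark_ocr_pip_args("4.0.2-abcdef")))
```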

## Prepare dataset

Here we download the vocabulary file (`LayoutLM.v2.voc.txt`) used by the model. To use your own data, put your images in one folder and prepare a labeling text file following the example dataset.

```python
!wget https://www.dropbox.com/s/4oc8gk6ermmf2lg/LayoutLM.v2.voc.txt
```


## Start Spark session with Spark OCR

```python
from sparkocr import start

# Path to a locally built Spark OCR jar; adjust it for your environment
spark_ocr_jar_path = "../../../spark-ocr/target/scala-2.12/"
spark = start(jar_path=spark_ocr_jar_path)

spark
```

```
Spark version: 3.2.0
Spark NLP version: 3.4.1
Spark NLP for Healthcare version: 3.3.2
Spark OCR version: 4.0.2rc1
```

## Preprocessing data

```python
from sparkocr.transformers import *

# Load the demo dataset: raw file bytes in `content`, class name in `act_label`
df = VisualDocumentClassifierV2.loadDataset("rvl_cdip_tmp", spark)
df.select("content", "act_label").show(1)
```

```
+--------------------+-------------+
|             content|    act_label|
+--------------------+-------------+
|[49 49 2A 00 34 E...|advertisement|
+--------------------+-------------+
```

```python
# Spread the rows over 7 partitions so that OCR runs in parallel
df = df.repartition(7)
df.rdd.getNumPartitions()
```

```
7
```

```python
from sparkocr.transformers import *
from sparkocr.enums import *
from pyspark.ml import PipelineModel

# Decode the raw bytes into an image
binary_to_image = BinaryToImage() \
    .setOutputCol("image") \
    .setImageType(ImageType.TYPE_3BYTE_BGR)

# Run OCR and produce hOCR (recognized text with word positions)
img_to_hocr = ImageToHocr() \
    .setInputCol("image") \
    .setOutputCol("hocr") \
    .setIgnoreResolution(False) \
    .setOcrParams(["preserve_interword_spaces=0"])

# Split the hOCR into tokens
tokenizer = HocrTokenizer() \
    .setInputCol("hocr") \
    .setOutputCol("token")

# OCR pipeline
pipeline1 = PipelineModel(stages=[
    binary_to_image,
    img_to_hocr,
    tokenizer
])

df = pipeline1.transform(df).cache()
# The classifier below reads the decoded image from the "orig_image" column
df = df.withColumnRenamed("image", "orig_image")
display(df.select("act_label", "pagenum", "exception", "hocr", "token"))
```

| act_label | pagenum | exception | hocr | token |
|---|---|---|---|---|
| questionnaire | 0 | | `<div class='ocr...` | `[{token, 0, 3, co...` |
| news_article | 0 | | `<div class='ocr...` | `[{token, 0, 0, %,...` |
| advertisement | 0 | | `<div class='ocr...` | `[{token, 2, 2, [,...` |
| scientific_public... | 0 | | `<div class='ocr...` | `[{token, 0, 2, wo...` |
| memo | 0 | | `<div class='ocr...` | `[{token, 2, 3, tr...` |
| form | 0 | | `<div class='ocr...` | `[{token, 4, 11, c...` |
| scientific_report | 0 | | `<div class='ocr...` | `[{token, 3, 7, ma...` |
| email | 0 | | `<div class='ocr...` | `[{token, 0, 9, bw...` |
| resume | 0 | | `<div class='ocr...` | `[{token, 0, 1, ot...` |
| specification | 0 | | `<div class='ocr...` | `[{token, 2, 9, ma...` |
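The `hocr` column holds the OCR result as hOCR markup, where each recognized word carries its bounding box in a `title` attribute. The following standard-library sketch illustrates the kind of information a layout-aware classifier takes from it (word text plus `bbox x0 y0 x1 y1`); it is not Spark OCR's `HocrTokenizer` implementation, and the sample markup is a hand-written assumption following the usual hOCR conventions (`ocrx_word` spans).

```python
from html.parser import HTMLParser
from typing import List, Optional, Tuple

Box = Tuple[int, int, int, int]


def parse_bbox(title: str) -> Optional[Box]:
    # hOCR properties are ';'-separated, e.g. "bbox 10 20 80 40; x_wconf 96"
    for prop in title.split(";"):
        parts = prop.split()
        if len(parts) == 5 and parts[0] == "bbox":
            x0, y0, x1, y1 = (int(v) for v in parts[1:])
            return (x0, y0, x1, y1)
    return None


class HocrWordParser(HTMLParser):
    """Collect (text, bbox) pairs from hOCR elements with class 'ocrx_word'."""

    def __init__(self) -> None:
        super().__init__()
        self.words: List[Tuple[str, Box]] = []
        self._bbox: Optional[Box] = None  # set while inside an ocrx_word span
        self._text: List[str] = []

    def handle_starttag(self, tag, attrs):
        attributes = dict(attrs)
        if "ocrx_word" in (attributes.get("class") or "").split():
            self._bbox = parse_bbox(attributes.get("title") or "")
            self._text = []

    def handle_data(self, data):
        if self._bbox is not None:
            self._text.append(data)

    def handle_endtag(self, tag):
        if tag == "span" and self._bbox is not None:
            text = "".join(self._text).strip()
            if text:
                self.words.append((text, self._bbox))
            self._bbox = None


def extract_words(hocr: str) -> List[Tuple[str, Box]]:
    parser = HocrWordParser()
    parser.feed(hocr)
    parser.close()
    return parser.words


if __name__ == "__main__":
    sample = (
        "<div class='ocr_page' title='bbox 0 0 1000 800'>"
        "<span class='ocr_line' title='bbox 10 20 200 40'>"
        "<span class='ocrx_word' title='bbox 10 20 80 40; x_wconf 96'>Invoice</span> "
        "<span class='ocrx_word' title='bbox 90 20 200 40; x_wconf 91'>Number</span>"
        "</span></div>"
    )
    print(extract_words(sample))
    # [('Invoice', (10, 20, 80, 40)), ('Number', (90, 20, 200, 40))]
```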


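```python
from sparkocr.utils import get_vocabulary_dict

# Load the vocabulary downloaded earlier, using "," as the delimiter
vocab_file = "LayoutLM.v2.voc.txt"
vocab = get_vocabulary_dict(vocab_file, ",")

doc_class = VisualDocumentClassifierV2() \
    .setInputCols(["token", "orig_image"]) \
    .setOutputCol("label")
doc_class.setVocabulary(vocab)

# Convert tokens and images into model inputs; [1, 3, 224, 224] is the
# image tensor shape (batch, channels, height, width)
result = doc_class.getPreprocessedDataset(
    df,
    [1, 3, 224, 224]
).cache()
```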
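`getPreprocessedDataset` produces the columns `input_ids`, `bbox`, `attention_mask` and `token_type_ids`. As a rough illustration of what such columns typically contain for LayoutLM-family models, here is a plain-Python sketch of the usual encoding: word boxes scaled to a 0-1000 grid, `[CLS]`/`[SEP]` tokens added, and everything padded to a fixed length. The special-token ids (101, 102, 0) are assumed BERT-style values, consistent with the leading 101 in the preprocessed `input_ids` but not confirmed here; the `[SEP]` box of 1000s follows the Hugging Face LayoutLMv2 tokenizer default. This is not Spark OCR's implementation, and its `bbox` column appears to be stored flattened rather than nested.

```python
from typing import Dict, List, Sequence, Tuple

Box = Tuple[int, int, int, int]

# Assumed special-token ids (BERT-style uncased vocabulary)
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def normalize_box(box: Box, width: int, height: int) -> List[int]:
    """Scale pixel coordinates to the 0-1000 range used by LayoutLM-family models."""
    x0, y0, x1, y1 = box
    return [
        int(1000 * x0 / width),
        int(1000 * y0 / height),
        int(1000 * x1 / width),
        int(1000 * y1 / height),
    ]


def encode(token_ids: Sequence[int], boxes: Sequence[Box],
           width: int, height: int, max_len: int = 512) -> Dict[str, list]:
    """Build padded model inputs from token ids and their (word-level) pixel boxes."""
    if len(token_ids) != len(boxes):
        raise ValueError("every token needs exactly one box")
    # Reserve two positions for [CLS] and [SEP], truncating the rest
    body_ids = list(token_ids[: max_len - 2])
    body_boxes = [normalize_box(b, width, height) for b in boxes[: max_len - 2]]

    input_ids = [CLS_ID] + body_ids + [SEP_ID]
    bbox = [[0, 0, 0, 0]] + body_boxes + [[1000, 1000, 1000, 1000]]
    attention_mask = [1.0] * len(input_ids)

    # Pad to max_len; padded positions are masked out
    pad = max_len - len(input_ids)
    input_ids += [PAD_ID] * pad
    bbox += [[0, 0, 0, 0]] * pad
    attention_mask += [0.0] * pad
    token_type_ids = [0] * max_len  # single-segment input
    return {
        "input_ids": input_ids,
        "bbox": bbox,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }


if __name__ == "__main__":
    # Hypothetical token ids for two words on a 1000x800 page
    enc = encode([2000, 3001], [(10, 20, 80, 40), (90, 20, 200, 40)],
                 width=1000, height=800, max_len=6)
    print(enc["input_ids"])       # [101, 2000, 3001, 102, 0, 0]
    print(enc["bbox"][1])         # [10, 25, 80, 50]
    print(enc["attention_mask"])  # [1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
```

When a word is split into several sub-word tokens, each token usually repeats the box of its word, which keeps `token_ids` and `boxes` aligned one to one.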

```python
# Save the model inputs and labels so that training can start from this folder
result.select("path", "input_ids", "bbox", "image", "attention_mask",
              "token_type_ids", "act_label") \
    .write.parquet("preprocessed_dataset")
```

## Preprocessed datasets

Some datasets are already available in preprocessed form and can be loaded directly, skipping the OCR and preprocessing steps above.

```python
df = VisualDocumentClassifierV2.loadPreprocessedDataset("rvl_cdip_tmp", spark)
df.show(1)
```

```
+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+
|           input_ids|                bbox|               image|      attention_mask|      token_type_ids|        act_label|
+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+
|[101, 13169, 1051...|[0, 0, 0, 0, 42, ...|[255, 255, 255, 2...|[1.0, 1.0, 1.0, 1...|[0, 0, 0, 0, 0, 0...|scientific_report|
+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+
only showing top 1 row
```


## Training

```python
from sparkocr.transformers import *

doc_class = VisualDocumentClassifierV2().fit(
    "preprocessed_dataset",            # folder written in the preprocessing step
    model_save_path="new_model",       # where the trained model is saved
    vocab_path="LayoutLM.v2.voc.txt",  # vocabulary downloaded earlier
    spark=spark,
    CHUNK_SIZE=2,
    num_train_epochs=2,
    useGPU=False,                      # set to True to train on a GPU
    do_validation=True
)
```