# Visual Document Classifier v2 training

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-ocr-workshop/blob/master/jupyter/VisualDocumentClassifierTraining/SparkOCRVisualDocumentClassifierv2Training.ipynb)

## Set license and AWS keys

You need to specify:
- the Spark OCR secret
- the license key
- AWS credentials

### Option #1 - define in this cell

In [2]:
import os

secret = ""
version = secret.split("-")[0]

os.environ['JSL_OCR_LICENSE'] = ""
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
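
A value left empty here tends to surface later as a confusing install or startup error. Below is a minimal sketch of a fail-fast check over the same four settings. The helper name `find_empty_settings` is hypothetical and not part of Spark OCR, and all values, including the shape of the secret, are made-up placeholders:

```python
def find_empty_settings(settings):
    """Return the names of settings whose value is missing or blank."""
    return [name for name, value in settings.items() if not (value or "").strip()]


# Hypothetical placeholder values; substitute your real credentials.
example_settings = {
    "secret": "4.0.2-abcdef",
    "JSL_OCR_LICENSE": "license-text",
    "AWS_ACCESS_KEY_ID": "",
    "AWS_SECRET_ACCESS_KEY": "aws-secret",
}

empty = find_empty_settings(example_settings)
if empty:
    print("Please fill in:", ", ".join(empty))

# As in the cell above, the version is the part of the secret before the first "-".
version = example_settings["secret"].split("-")[0]
```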

### Option #2 - provide spark_ocr.json file

In [None]:
import json, os
import sys

if 'google.colab' in sys.modules:
    from google.colab import files

    if 'spark_ocr.json' not in os.listdir():
      license_keys = files.upload()
      os.rename(list(license_keys.keys())[0], 'spark_ocr.json')

with open('spark_ocr.json') as f:
    license_keys = json.load(f)

# Defining license key-value pairs as local variables
locals().update(license_keys)
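
`locals().update(...)` works at the top level of a notebook cell, where the local and global namespaces coincide, but it does not create variables when called inside a function. A minimal alternative sketch copies selected keys into `os.environ` explicitly. The helper name `export_keys` and the key names in the demo are assumptions for illustration, so check them against your own `spark_ocr.json`:

```python
import json
import os
import tempfile


def export_keys(json_path, wanted):
    """Copy the `wanted` keys from a JSON file into os.environ; return the missing ones."""
    with open(json_path) as f:
        keys = json.load(f)
    missing = []
    for name in wanted:
        if name in keys:
            os.environ[name] = str(keys[name])
        else:
            missing.append(name)
    return missing


# Demo with a throwaway file and made-up key names, so no real credentials are touched.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "spark_ocr.json")
    with open(demo_path, "w") as f:
        json.dump({"EXAMPLE_LICENSE": "demo-license"}, f)
    missing = export_keys(demo_path, ["EXAMPLE_LICENSE", "EXAMPLE_SECRET"])
```

Returning the missing names instead of raising lets the caller decide whether an absent key is fatal.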

## Install Spark-OCR

This step is needed only on Colab. In other environments, install Spark OCR in whatever way suits your setup.

In [None]:
%pip install spark-ocr==$version+spark32 --extra-index-url=https://pypi.johnsnowlabs.com/$secret --upgrade

## Prepare dataset

This example downloads a demo set. For your own data, put all images in one folder and prepare a label text file in the same format as the example (`demo_set_labels.txt`).
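
Judging by the parsing code later in this notebook, each line of the label file holds an image path and a numeric class id separated by a space. Below is a minimal sketch of writing such a file for your own images. The file names, the ids, and the helper `write_label_file` are made up for illustration:

```python
import os
import tempfile


def write_label_file(labelled_images, out_path):
    """Write one "<image path> <class id>" line per image."""
    with open(out_path, "w") as f:
        for image_path, class_id in labelled_images:
            f.write(f"{image_path} {int(class_id)}\n")


# Hypothetical images and ids; the ids index the label_names mapping (e.g. 11 is "invoice").
my_images = [("scan_001.tif", 11), ("scan_002.tif", 0)]

with tempfile.TemporaryDirectory() as tmp:
    labels_path = os.path.join(tmp, "my_labels.txt")
    write_label_file(my_images, labels_path)
    with open(labels_path) as f:
        written = f.read()
```

File names containing spaces would confuse a parser that splits each line on spaces, so it is safer to avoid them.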

In [3]:
!wget https://www.dropbox.com/s/4oc8gk6ermmf2lg/LayoutLM.v2.voc.txt
!wget https://www.dropbox.com/s/6w2acrrj36wezuw/RVL_CDIP_demo_set.zip
!unzip RVL_CDIP_demo_set.zip


## Start Spark session with Spark OCR

In [3]:
from sparkocr import start

# Directory with a locally built Spark OCR jar; adjust the path to your environment
spark_ocr_jar_path = "../../../spark-ocr/target/scala-2.12/"
spark = start(jar_path = spark_ocr_jar_path)

spark

Spark version: 3.2.0
Spark NLP version: 3.4.1
Spark NLP for Healthcare version: 3.3.2
Spark OCR version: 4.0.2rc1



## Preprocessing data

In [5]:
imagePath = "RVL_CDIP_demo_set/*.tif"
df = spark.read.format("binaryFile").load(imagePath)
print(df.count())

15


In [6]:
label_names = {0: "letter",
               1: "form",
               2: "email",
               3: "handwritten",
               4: "advertisement",
               5: "scientific_report",
               6: "scientific_publication",
               7: "specification",
               8: "file_folder",
               9: "news_article",
               10: "budget",
               11: "invoice",
               12: "presentation",
               13: "questionnaire",
               14: "resume",
               15: "memo"
}

# Map each image file name to its class name
files_labelled = {}
with open("RVL_CDIP_demo_set/demo_set_labels.txt") as file:
    for line in file:
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        # Each line is "<image path> <numeric label id>"
        files_labelled[os.path.basename(parts[0])] = label_names[int(parts[1])]

In [7]:
from pyspark.sql.functions import udf

def get_label(fl):
    # Look up the label by bare file name, because `path` holds a full file: URI
    fname = os.path.basename(fl)
    if fname in files_labelled:
        return files_labelled[fname]
    print("No label found for file:", fname)
    return None

get_label_udf = udf(get_label)

df = df.withColumn("act_label", get_label_udf("path"))
df = df.dropna(subset="act_label")
df.show(5)

+--------------------+-------------------+------+--------------------+--------------------+
|                path|   modificationTime|length|             content|           act_label|
+--------------------+-------------------+------+--------------------+--------------------+
|file:/home/alexan...|2020-06-11 10:25:16|225292|[49 49 2A 00 36 6...|       questionnaire|
|file:/home/alexan...|2020-06-11 10:23:36|222494|[49 49 2A 00 48 6...|        news_article|
|file:/home/alexan...|2020-06-11 10:17:28|186252|[49 49 2A 00 B6 D...|       advertisement|
|file:/home/alexan...|2020-06-11 10:26:44|160494|[49 49 2A 00 18 7...|scientific_public...|
|file:/home/alexan...|2020-06-11 10:22:58|152014|[49 49 2A 00 F8 4...|                memo|
+--------------------+-------------------+------+--------------------+--------------------+
only showing top 5 rows
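
Messages printed inside a UDF are emitted by the Python worker processes and may not show up in the notebook output. Below is a minimal sketch of the same check done on the driver with plain Python, before the DataFrame is built. `find_unlabelled` is a hypothetical helper, and the demo uses a temporary folder rather than the real dataset:

```python
import glob
import os
import tempfile


def find_unlabelled(image_glob, labelled_names):
    """Return sorted file names matching `image_glob` that have no label entry."""
    names = (os.path.basename(p) for p in glob.glob(image_glob))
    return sorted(n for n in names if n not in labelled_names)


with tempfile.TemporaryDirectory() as tmp:
    # Create three empty placeholder images, only two of which are labelled.
    for name in ("a.tif", "b.tif", "c.tif"):
        open(os.path.join(tmp, name), "w").close()
    demo_labels = {"a.tif": "letter", "c.tif": "memo"}
    unlabelled = find_unlabelled(os.path.join(tmp, "*.tif"), demo_labels)
```

With the data in this notebook, the equivalent call would be `find_unlabelled("RVL_CDIP_demo_set/*.tif", files_labelled)`.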



In [34]:
df = df.repartition(7)
df.rdd.getNumPartitions()

7

In [8]:
from sparkocr.transformers import *
from sparkocr.enums import *
from pyspark.ml import PipelineModel

binary_to_image = BinaryToImage()\
    .setOutputCol("image") \
    .setImageType(ImageType.TYPE_3BYTE_BGR)

img_to_hocr = ImageToHocr()\
    .setInputCol("image")\
    .setOutputCol("hocr")\
    .setIgnoreResolution(False)\
    .setOcrParams(["preserve_interword_spaces=0"])

tokenizer = HocrTokenizer()\
    .setInputCol("hocr")\
    .setOutputCol("token")

# OCR pipeline
pipeline1 = PipelineModel(stages=[
    binary_to_image,
    img_to_hocr,
    tokenizer
])

df = pipeline1.transform(df).cache()
df = df.withColumnRenamed("image", "orig_image")
display(df)

orig_image,path,modificationTime,length,act_label,pagenum,exception,hocr,token
{file:/home/alexa...,file:/home/alexan...,2020-06-11 10:25:16,225292,questionnaire,0,,<div class='ocr...,"[{token, 0, 3, co..."
{file:/home/alexa...,file:/home/alexan...,2020-06-11 10:23:36,222494,news_article,0,,<div class='ocr...,"[{token, 0, 0, %,..."
{file:/home/alexa...,file:/home/alexan...,2020-06-11 10:17:28,186252,advertisement,0,,<div class='ocr...,"[{token, 2, 2, [,..."
{file:/home/alexa...,file:/home/alexan...,2020-06-11 10:26:44,160494,scientific_public...,0,,<div class='ocr...,"[{token, 0, 2, wo..."
{file:/home/alexa...,file:/home/alexan...,2020-06-11 10:22:58,152014,memo,0,,<div class='ocr...,"[{token, 2, 3, tr..."
{file:/home/alexa...,file:/home/alexan...,2020-06-11 10:20:32,122628,form,0,,<div class='ocr...,"[{token, 4, 11, c..."
{file:/home/alexa...,file:/home/alexan...,2020-06-11 10:27:46,111966,scientific_report,0,,<div class='ocr...,"[{token, 3, 7, ma..."
{file:/home/alexa...,file:/home/alexan...,2020-06-11 10:19:50,105638,email,0,,<div class='ocr...,"[{token, 0, 9, bw..."
{file:/home/alexa...,file:/home/alexan...,2020-06-11 10:25:56,91164,resume,0,,<div class='ocr...,"[{token, 0, 1, ot..."
{file:/home/alexa...,file:/home/alexan...,2020-06-11 10:28:20,84974,specification,0,,<div class='ocr...,"[{token, 2, 9, ma..."


In [10]:
from sparkocr.utils import get_vocabulary_dict

vocab_file = "LayoutLM.v2.voc.txt"
vocab = get_vocabulary_dict(vocab_file, ",")

doc_class = VisualDocumentClassifierV2() \
    .setInputCols(["token", "orig_image"]) \
    .setOutputCol("label")
doc_class.setVocabulary(vocab)

result = doc_class.getPreprocessedDataset(
  df,
  [1,3,224,224]
  ).cache()

invalid literal for int() with base 10: '' ##,,29623
 ['##', '', '29623']
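
The message above suggests that one vocabulary line, `##,,29623`, was split on every comma, so a token that itself contains a comma yields three fields instead of two. How `get_vocabulary_dict` parses lines internally is not shown here. The following is only a standalone sketch of the general technique, splitting on the last separator, assuming each line has the form `token,id`; the helper name is hypothetical:

```python
def parse_vocab_line(line, sep=","):
    """Split "token<sep>id" on the LAST separator so tokens may contain `sep`."""
    token, raw_id = line.rstrip("\n").rsplit(sep, 1)
    return token, int(raw_id)


# Splitting on every comma reproduces the three fields shown in the message.
naive = "##,,29623".split(",")

# Splitting once from the right keeps the comma as part of the token.
robust = parse_vocab_line("##,,29623")
```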


In [11]:
result.select("path", "input_ids", "bbox", "image", "attention_mask", "token_type_ids", "act_label").write.parquet("preprocessed_dataset")

## Training

In [None]:
from sparkocr.transformers import *

doc_class = VisualDocumentClassifierV2()
doc_class.fit("preprocessed_dataset", model_save_path = "new_model", CHUNK_SIZE = 2, num_train_epochs = 2)