# Visual Document Classifier LiLT training

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-ocr-workshop/blob/master/jupyter/VisualDocumentClassifierTraining/SparkOCRVisualDocumentClassifierLiltTraining.ipynb)

## Set license and AWS keys

You need to specify:
- secret
- license
- AWS credentials

### Option #1 - define in this cell

In [1]:
import os

secret = ""
version = secret.split("-")[0]

os.environ['SPARK_OCR_LICENSE'] = ""
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""


### Option #2 - provide spark_ocr.json file

In [2]:
import json, os
import sys

if 'google.colab' in sys.modules:
    from google.colab import files

    if 'spark_ocr.json' not in os.listdir():
        license_keys = files.upload()
        os.rename(list(license_keys.keys())[0], 'spark_ocr.json')

with open('spark_ocr.json') as f:
    license_keys = json.load(f)

# Defining license key-value pairs as local variables
locals().update(license_keys)
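
`locals().update(...)` works at the top level of a notebook, but it silently accepts a file with missing keys. As an alternative, here is a minimal sketch of a more explicit loader. The helper name `load_license_keys` is hypothetical, and the key name `SPARK_OCR_SECRET` in the usage comment is an assumption about the contents of `spark_ocr.json`; check your own file for the actual names.

```python
import json
import os


def load_license_keys(path, required=()):
    """Read a JSON key file, export its string values to os.environ, return the dict.

    `required` lists key names that must be present and non-empty.
    """
    with open(path) as f:
        keys = json.load(f)

    # Fail early with a clear message instead of hitting a NameError later.
    missing = [name for name in required if not keys.get(name)]
    if missing:
        raise KeyError(f"Missing keys in {path}: {missing}")

    # Only string values can be stored in environment variables.
    for name, value in keys.items():
        if isinstance(value, str):
            os.environ[name] = value
    return keys


# Usage in the notebook (key names are assumptions, adjust to your file):
# license_keys = load_license_keys("spark_ocr.json", required=("SPARK_OCR_LICENSE",))
# secret = license_keys["SPARK_OCR_SECRET"]
# version = secret.split("-")[0]
```

With this approach the variables used in later cells (`secret`, `version`) are assigned explicitly, so it is easier to see where they come from.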

## Install Spark-OCR

This step is needed only on Colab. In other environments, prepare the environment accordingly.

In [3]:
# Installing Dependencies
%pip install -q git+https://github.com/JohnSnowLabs/transformers.git@LiltOnnx
%pip install spark-ocr==$version --extra-index-url=https://pypi.johnsnowlabs.com/$secret --upgrade

## Download demo datasets

This step downloads a demo dataset. For your own data, put your images in one folder and prepare a labelling txt file as in the example.

The instructions here use the command line; you can also download and unzip these files manually.

In [4]:
!wget https://s3.amazonaws.com/dev.johnsnowlabs.com/ocr/test_models/LiLT_vocabulary.txt
!wget https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/datasets/visual_doc_classifier/rvl_cdip_tmp.zip
!unzip rvl_cdip_tmp.zip

## Start Spark session with Spark OCR

In [5]:
from sparkocr import start
from pyspark import SparkConf

spark_ocr_jar_path = "../../target/scala-2.12/"
spark = start(jar_path = spark_ocr_jar_path)

spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

spark

Spark version: 3.2.2
Spark NLP version: 4.3.1
Spark NLP for Healthcare version: 4.3.1
Spark OCR version: 5.0.0



### Define labels

In [6]:
labels = ["advertisement",
          "budget",
          "email",
          "file_folder",
          "form",
          "handwritten",
          "invoice",
          "letter",
          "memo",
          "news_article",
          "presentation",
          "questionnaire",
          "resume",
          "scientific_publication",
          "scientific_report",
          "specification"]

Images for classification should be placed in one folder ("./rvl_cdip_tmp" in this case).

The labels file should be placed in the same folder. The format is one record per row: the file path and the label, separated by a space, for example:

```
file1.jpg 1
file2.jpg 2
```
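
If your annotations are stored as label names rather than numbers, a file in this format can be generated with a few lines of Python. The sketch below is illustrative only: the helper `write_labels_file` and the output name `labels.txt` are hypothetical, and it assumes that the numeric label is the zero-based index into the `labels` list defined above (which matches the demo output, but verify it against the demo dataset before relying on it).

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}


def write_labels_file(image_dir, label_by_file, labels, out_name="labels.txt"):
    """Write '<file_name> <label_index>' rows for the annotated images in image_dir.

    label_by_file maps an image file name to a label name from `labels`.
    An unknown label name raises KeyError.
    """
    image_dir = Path(image_dir)
    # Assumption: the label number is the zero-based position in `labels`.
    index_of = {name: i for i, name in enumerate(labels)}

    rows = []
    for path in sorted(image_dir.iterdir()):
        if path.suffix.lower() not in IMAGE_SUFFIXES:
            continue  # ignore non-image files, including the labels file itself
        label_name = label_by_file.get(path.name)
        if label_name is None:
            continue  # skip images without an annotation
        rows.append(f"{path.name} {index_of[label_name]}")

    out_path = image_dir / out_name
    out_path.write_text("\n".join(rows) + "\n")
    return out_path


# Usage (hypothetical annotations):
# write_labels_file("./my_images", {"scan_001.jpg": "invoice"}, labels)
```

Because the separator is a single space, file names that contain spaces would not fit this format and should be renamed first.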

In [None]:
from sparkocr.transformers import *

df = DatasetReader.readDataset("./rvl_cdip_tmp", spark)
display(df.select("content", "act_label").limit(1))

### Repartition your data
To make better use of your cluster, you may need to repartition the input DataFrame.

In [8]:
df = df.repartition(8)
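
The value `8` is specific to this demo. As a rough starting point, the Spark tuning guide suggests 2-3 tasks per CPU core. The sketch below turns that rule of thumb into a small helper; `suggest_num_partitions` is a hypothetical name, and the right number for OCR workloads depends on image size and memory, so treat the result as a first guess to tune.

```python
def suggest_num_partitions(total_cores, tasks_per_core=2, num_rows=None):
    """Return a partition count based on the available cores.

    The result is capped at num_rows (when given) so that small datasets
    do not end up with empty partitions, and is never less than 1.
    """
    n = max(1, total_cores * tasks_per_core)
    if num_rows is not None:
        n = max(1, min(n, num_rows))
    return n


# Usage in the notebook (defaultParallelism is the core count Spark reports):
# cores = spark.sparkContext.defaultParallelism
# df = df.repartition(suggest_num_partitions(cores))
```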

In [9]:
from sparkocr.transformers import *
from sparkocr.enums import *
from pyspark.ml import PipelineModel

binary_to_image = BinaryToImage()\
    .setOutputCol("image") \
    .setImageType(ImageType.TYPE_3BYTE_BGR)

img_to_hocr = ImageToHocr()\
    .setInputCol("image")\
    .setOutputCol("hocr")\
    .setIgnoreResolution(False)\
    .setOcrParams(["preserve_interword_spaces=0"])

# OCR pipeline
pipeline1 = PipelineModel(stages=[
    binary_to_image,
    img_to_hocr,
])

df = pipeline1.transform(df).cache()
df = df.withColumnRenamed("image", "orig_image")
display(df.select("act_label", "pagenum", "exception", "hocr"))

act_label,pagenum,exception,hocr
4,0,,<div class='ocr...


In [None]:
from sparkocr.utils import get_vocabulary_dict

vocab_file = "LiLT_vocabulary.txt"
vocab = get_vocabulary_dict(vocab_file, ",")

trainer = VisualDocumentClassifierLilt() \
    .setInputCol("hocr") \
    .setOutputCol("label") \
    .set_train_param_num_train_epochs(2) \
    .set_train_param_useGPU(False) \
    .setLabels(labels)
trainer.setVocabulary(vocab)
doc_class = trainer.fit(df)


In [14]:
res = doc_class \
    .setOutputCol("label") \
    .transform(df)
display(res.select("label", "pagenum", "exception", "hocr"))

label,pagenum,exception,hocr
form,0,,<div class='ocr...
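
To get a quick sense of how well the predictions match the ground truth, the predicted `label` can be compared with `act_label`. The sketch below is an illustration, not part of the Spark OCR API: `accuracy_against_labels` is a hypothetical helper, and it assumes that `act_label` is a zero-based index into `labels` while `label` holds the predicted class name, as the outputs above suggest.

```python
def accuracy_against_labels(rows, labels):
    """Return the fraction of rows whose predicted name matches the actual index.

    rows: iterable of (act_label, predicted_label) pairs, where act_label is an
    int (or numeric string) index into `labels` and predicted_label is a name.
    """
    total = 0
    correct = 0
    for act_label, predicted in rows:
        total += 1
        if labels[int(act_label)] == predicted:
            correct += 1
    return correct / total if total else 0.0


# Usage in the notebook (collects the two columns to the driver):
# rows = [(r["act_label"], r["label"])
#         for r in res.select("act_label", "label").collect()]
# print(accuracy_against_labels(rows, labels))
```

Note that `res` here is computed on the same data the model was trained on, so this number is optimistic; hold out a separate set of images for a meaningful estimate.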
