## Fine-Tuning a DiT-Based Visual Document Classifier on RVL-CDIP
### 1. Setup Libraries

In [None]:
!python -m pip install --upgrade spark-ocr==5.0.1 --user --extra-index-url https://pypi.johnsnowlabs.com/SECRET

In [4]:
from sparkocr.transformers import *
from sparkocr.transformers.readers.rvlcdip_reader import RvlCdipReader
from sparkocr import start
import pyspark

### 2. Start Spark session
Define some Spark configs first.

In [1]:
extras = {"spark.driver.maxResultSize":"3500m",
          "spark.kryoserializer.buffer.max": "1000M"}

#### 2.1 Optional: Setup S3 access
If your dataset is hosted on S3, Spark needs the right dependencies to access your bucket. Run the following cell to add them according to your Spark version.
Skip the next cell if you are using another storage option.

In [5]:
spark_to_aws_hadoop = {"3.0": "2.7.4", "3.1": "3.2.0", "3.2": "3.3.1", "3.3": "3.3.2", "3.4":"3.3.4"}
spark_version = pyspark.__version__[:3]
aws_version = spark_to_aws_hadoop[spark_version]

extras["spark.jars.packages"] = "org.apache.hadoop:hadoop-aws:"+aws_version
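The cell above slices the first three characters of the version string and raises a bare KeyError for any Spark version missing from the table. A slightly more defensive lookup could look like the sketch below. The helper name `hadoop_aws_version` is hypothetical (it is not part of Spark OCR), and the table is copied from the cell above.

```python
# Mapping copied from the notebook cell: Spark "major.minor" -> hadoop-aws version.
SPARK_TO_AWS_HADOOP = {"3.0": "2.7.4", "3.1": "3.2.0", "3.2": "3.3.1",
                       "3.3": "3.3.2", "3.4": "3.3.4"}


def hadoop_aws_version(spark_version, table=SPARK_TO_AWS_HADOOP):
    """Return the hadoop-aws version for a Spark version string (hypothetical helper)."""
    # Split on dots instead of slicing, so "3.10.0" is read as "3.10" and not "3.1".
    major_minor = ".".join(spark_version.split(".")[:2])
    if major_minor not in table:
        raise ValueError(
            f"No hadoop-aws version known for Spark {major_minor}; "
            f"known versions: {sorted(table)}"
        )
    return table[major_minor]


print(hadoop_aws_version("3.2.1"))  # 3.3.1
```

In the notebook this would be called as `aws_version = hadoop_aws_version(pyspark.__version__)`.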

In [6]:
spark = start(jar_path="../",
              extra_conf=extras)

spark.sparkContext._jsc.hadoopConfiguration().set("fs.s3a.path.style.access", "true")

Spark version: 3.2.1
Spark NLP version: 5.0.2
Spark OCR version: 5.0.1rc2

:: loading settings :: url = jar:file:/opt/spark/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /root/.ivy2/cache
The jars for the packages stored in: /root/.ivy2/jars
org.apache.hadoop#hadoop-aws added as a dependency
com.johnsnowlabs.nlp#spark-nlp_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-5e5c0cd2-30e6-4ecd-a8ab-acc30a0139f4;1.0
	confs: [default]
	found org.apache.hadoop#hadoop-aws;3.3.1 in central
	found com.amazonaws#aws-java-sdk-bundle;1.11.901 in central
	found org.wildfly.openssl#wildfly-openssl;1.0.7.Final in central
	found com.johnsnowlabs.nlp#spark-nlp_2.12;5.0.2 in central
	found com.typesafe#config;1.4.2 in central
	found org.rocksdb#rocksdbjni;6.29.5 in central
	found com.github.universal-automata#liblevenshtein;3.0.0 in central
	found com.google.protobuf#protobuf-java-util;3.0.0-beta-3 in central
	found com.google.protobuf#protobuf-java;3.0.0-beta-3 in central
	found com.google.code.gson#gson;2.3 in central
	found it.unimi.dsi#fastutil;7.0.12 in central
	found org.projectlombok#lombok;1.16.

### 3. Load the Training Dataset
Let's use Visual NLP's RvlCdipReader utility functions to load the RVL-CDIP training dataset. First, let's look at the documentation to understand the available parameters and options.

In [5]:
help(RvlCdipReader().readTrainDataset)

Help on method readTrainDataset in module sparkocr.transformers.readers.rvlcdip_reader:

readTrainDataset(spark, labels_path, images_path, partitions=8, storage_level=StorageLevel(True, False, False, False, 1)) method of sparkocr.transformers.readers.rvlcdip_reader.RvlCdipReader instance
    Reads the dataset from an external resource.
    
    Parameters
    ----------
    spark : :class:`pyspark.sql.SparkSession`
        Initiated Spark Session with Spark NLP
    labels_path : str
        The path to the labels file, i.e., labels/train.txt
    images_path : str
        the path where you unzipped the files for RvlCdip Train Images
     partitions : int
        sets the minimum number of partitions for the case of lifting multiple files in parallel into a single dataframe. Defaults to 8.
    storage_level : sets the persistence level according to PySpark definitions. Defaults to StorageLevel.DISK_ONLY.



So, we need to provide two paths: one to the labels file and one to the images. In this case, we previously flattened all the images into a single folder to make data access faster in Spark. This is not mandatory, but it is recommended on filesystems where listing folders is expensive.
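To make the two inputs concrete, here is a minimal sketch of how a labels file entry could be matched to a file in a flattened image folder. It assumes the layout of the public RVL-CDIP distribution, where each line of labels/train.txt holds a relative image path and an integer class id separated by a space. The notebook does not show the file's contents, and RvlCdipReader may well handle this differently. The helper name and the sample line are hypothetical.

```python
import posixpath


def parse_label_line(line):
    """Split one labels-file line into (image file name, class id). Hypothetical helper."""
    # Assumption: "<relative/path/to/image.tif> <class id>", separated by whitespace.
    path, class_id = line.rsplit(maxsplit=1)
    # With a flattened image folder, only the file name is left to match on.
    return posixpath.basename(path), int(class_id)


sample = "images/a/b/c/doc0001/0001.tif 3"  # hypothetical line, not taken from the dataset
print(parse_label_line(sample))  # ('0001.tif', 3)
```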

In [6]:
labels_path = "s3a://dev.johnsnowlabs.com/ocr/datasets/rvl_cdip_train_labels.txt"
images_path = "s3a://dev.johnsnowlabs.com/ocr/datasets/rvl_cdip_full_/*.tif"

In [None]:
train_df = RvlCdipReader().readTrainDataset(spark, labels_path, images_path)

23/09/15 14:25:41 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties

#### 3.1 Optional: Create a smaller local copy to experiment with different parameters
You don't need to work with the entire dataset all the time. Until all the fine-tuning parameters are properly set up, it is recommended to work on a smaller subset of your data.

In [7]:
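# Persist a 400-row subset locally; overwrite so the cell can be re-run safely.
train_df.limit(400).write.mode("overwrite").parquet("./train_df")
train_df = spark.read.parquet("./train_df")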

                                                                                

In [8]:
train_df.printSchema()

root
 |-- path: string (nullable = true)
 |-- modificationTime: timestamp (nullable = true)
 |-- length: long (nullable = true)
 |-- content: binary (nullable = true)
 |-- act_label: string (nullable = true)



In [9]:
train_df.count()

400

Let's take a look at the different labels present in the dataset.

In [10]:
labels = RvlCdipReader()._labels
labels

['resume',
 'handwritten',
 'memo',
 'email',
 'questionnaire',
 'scientific_report',
 'invoice',
 'advertisement',
 'news_article',
 'form',
 'scientific_publication',
 'file_folder',
 'budget',
 'specification',
 'presentation',
 'letter']
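With only 400 rows in the local subset, some of the 16 classes may be rare or missing entirely. The sketch below is a plain-Python sanity check for that. The function name and the sample values are hypothetical. In the notebook, the observed values could be collected with `[r.act_label for r in train_df.select("act_label").collect()]`.

```python
from collections import Counter

# Label names copied from the notebook output above.
LABELS = ['resume', 'handwritten', 'memo', 'email', 'questionnaire',
          'scientific_report', 'invoice', 'advertisement', 'news_article',
          'form', 'scientific_publication', 'file_folder', 'budget',
          'specification', 'presentation', 'letter']


def summarize_labels(observed, known=LABELS):
    """Count observed labels and report unknown or missing ones (hypothetical helper)."""
    counts = Counter(observed)
    unknown = sorted(set(counts) - set(known))                # values outside the label list
    missing = [name for name in known if name not in counts]  # classes with no examples
    return counts, unknown, missing


observed = ["memo", "email", "memo", "letter"]  # hypothetical sample of act_label values
counts, unknown, missing = summarize_labels(observed)
print(counts.most_common(1))  # [('memo', 2)]
print(unknown)                # []
print(len(missing))           # 13
```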

### 4. Training Pipeline and Model
In this section, we define the transformers used in the fine-tuning process.<br>
BinaryToImage: converts the binary content into an image structure that carries information about the image, such as resolution and number of channels.<br>
VisualDocumentClassifierV3: this is our model. Here we define training parameters such as the number of epochs, the labels, and the batch size.<br>

In [11]:
binary_to_image = BinaryToImage() \
    .setOutputCol("image")

# Training needs both the image and its ground-truth label as input columns.
classifier = VisualDocumentClassifierV3() \
    .setInputCols(["image", "act_label"]) \
    .setOutputCol("entities") \
    .setLabels(labels) \
    .setTrainBatchSize(32) \
    .setTrainEpochs(1)

Let's add the 'image' column to the training data.

In [12]:
images_labels = binary_to_image.transform(train_df)
images_labels.printSchema()

root
 |-- image: struct (nullable = true)
 |    |-- origin: string (nullable = true)
 |    |-- height: integer (nullable = false)
 |    |-- width: integer (nullable = false)
 |    |-- nChannels: integer (nullable = false)
 |    |-- mode: integer (nullable = false)
 |    |-- resolution: integer (nullable = false)
 |    |-- data: binary (nullable = true)
 |-- exception: string (nullable = true)
 |-- path: string (nullable = true)
 |-- modificationTime: timestamp (nullable = true)
 |-- length: long (nullable = true)
 |-- act_label: string (nullable = true)
 |-- pagenum: integer (nullable = true)



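#### 4.1 Optional: Dataset cleanup
Encoding issues sometimes appear in some files. The cell below lists the paths of rows whose image column is null (files that could not be decoded), so that they can be excluded from training.

In [None]:
images_labels.filter(images_labels["image"].isNull()).select("path").show(truncate=False)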

### 5. Trigger fine-tuning and save model

In [13]:
fitted_classifier = classifier.fit(images_labels,
                          base_model="dit-base-224-p16-500k-62d53a.pth",
                          validate=True)

/root/cache_pretrained/ocr/base/dit-base-224-p16-500k-62d53a.pth


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
17:39:38, INFO Patch size = (16, 16)
17:39:38, INFO Load ckpt from /root/cache_pretrained/ocr/base/dit-base-224-p16-500k-62d53a.pth
17:39:38, INFO Load state_dict by model_key = model


Weights of Beit not initialized from pretrained model: ['blocks.0.attn.relative_position_bias_table', 'blocks.1.attn.relative_position_bias_table', 'blocks.2.attn.relative_position_bias_table', 'blocks.3.attn.relative_position_bias_table', 'blocks.4.attn.relative_position_bias_table', 'blocks.5.attn.relative_position_bias_table', 'blocks.6.attn.relative_position_bias_table', 'blocks.7.attn.relative_position_bias_table', 'blocks.8.attn.relative_position_bias_table', 'blocks.9.attn.relative_position_bias_table', 'blocks.10.attn.relative_position_bias_table', 'blocks.11.attn.relative_position_bias_table']
Weights from pretrained model not used in Beit: ['pos_embed']
Ignored weights of Beit not initialized from pretrained model: ['blocks.0.attn.relative_position_index', 'blocks.1.attn.relative_position_index', 'blocks.2.attn.relative_position_index', 'blocks.3.attn.relative_position_index', 'blocks.4.attn.relative_position_index', 'blocks.5.attn.relative_position_index', 'blocks.6.attn.rel

17:39:39, INFO Model = Beit(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=False)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
    )
    (1-11): 11 x Bl

LR = 0.00050000
Batch size = 64
Update frequent = 2
Number of training examples = 253
Number of training training per epoch = 3
Assigned values = [0.023757264018058777, 0.03167635202407837, 0.04223513603210449, 0.056313514709472656, 0.07508468627929688, 0.1001129150390625, 0.13348388671875, 0.177978515625, 0.2373046875, 0.31640625, 0.421875, 0.5625, 0.75, 1.0]
Param groups = {
  "layer_0_no_decay": {
    "weight_decay": 0.0,
    "params": [
      "cls_token",
      "patch_embed.proj.bias"
    ],
    "lr_scale": 0.023757264018058777
  },
  "layer_0_decay": {
    "weight_decay": 0.05,
    "params": [
      "patch_embed.proj.weight"
    ],
    "lr_scale": 0.023757264018058777
  },
  "layer_1_no_decay": {
    "weight_decay": 0.0,
    "params": [
      "blocks.0.gamma_1",
      "blocks.0.gamma_2",
      "blocks.0.norm1.weight",
      "blocks.0.norm1.bias",
      "blocks.0.attn.q_bias",
      "blocks.0.attn.v_bias",
      "blocks.0.attn.relative_position_bias_table",
      "blocks.0.attn.pro

                                                                                

Epoch: [0]  [0/7]  eta: 0:00:54  lr: 0.000000  min_lr: 0.000000  loss: 3.4996 (3.4996)  class_acc: 0.1875 (0.1875)  loss_scale: 65536.0000 (65536.0000)  weight_decay: 0.0500 (0.0500)  time: 7.7264  data: 4.4407  max mem: 3878


                                                                                

Averaged stats: lr: 0.000017  min_lr: 0.000000  loss: 3.2738 (3.4032)  class_acc: 0.0938 (0.1250)  loss_scale: 32768.0000 (49152.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 26.2609 (inf)


                                                                                

Test:  [0/4]  eta: 0:00:14  loss: 3.7177 (3.7177)  acc1: 8.3333 (8.3333)  acc5: 56.2500 (56.2500)  time: 3.5353  data: 3.4671  max mem: 4874


                                                                                

Test:  [3/4]  eta: 0:00:02  loss: 3.1346 (3.2899)  acc1: 8.3333 (16.1458)  acc5: 56.2500 (57.2917)  time: 2.1658  data: 2.0983  max mem: 4874



* Acc@1 16.146 Acc@5 57.292 loss 3.290
Accuracy of the network on the 147 test images: 16.1%
Max accuracy: 16.15%
Training time 0:00:45
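To put these numbers in context, assume that Acc@1 and Acc@5 carry their usual meaning: the share of test images whose true label is the top prediction, or among the five highest-scoring predictions. The sketch below computes both from raw scores and compares them with the chance level for 16 classes. The function name and the toy scores are hypothetical, and the chance levels assume evenly distributed classes, which the 147-image validation subset does not guarantee.

```python
def top_k_accuracy(scores, targets, k):
    """Percentage of rows whose target index is among the k highest scores."""
    hits = 0
    for row, target in zip(scores, targets):
        # Indices of the k largest scores in this row.
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += target in top_k
    return 100.0 * hits / len(targets)


num_classes = 16
chance_top1 = 100.0 * 1 / num_classes  # random guessing, top-1
chance_top5 = 100.0 * 5 / num_classes  # random guessing, top-5
print(chance_top1, chance_top5)  # 6.25 31.25

# Toy example with 3 classes and 2 samples (hypothetical scores).
scores = [[0.1, 0.9, 0.0], [0.8, 0.15, 0.05]]
targets = [1, 1]
print(top_k_accuracy(scores, targets, k=1))  # 50.0
print(top_k_accuracy(scores, targets, k=2))  # 100.0
```

Under these assumptions, the logged 16.1% (Acc@1) and 57.3% (Acc@5) after a single epoch on 253 examples are above chance but still far from a usable classifier, which is expected for such a short run on a small subset.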


17:40:57, INFO Export to onnx...
  assert condition, message
17:41:00, INFO Model exported to: /tmp/tmpwov7qqu3onnx_tmp/model.onxx
17:41:00, INFO Storing model.


verbose: False, log level: Level.ERROR



17:41:00, INFO Removing cache...
17:41:01, INFO Failed to remove cache...
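Several numbers in the training log above can be reproduced from the settings used in this notebook. The sketch below is an interpretation of the log, not the trainer's actual code. It assumes that the logged batch size is the configured batch size multiplied by the update frequency, that incomplete batches are dropped, and that the "Assigned values" line is a BEiT-style layer-wise learning-rate decay with a factor of 0.75 over 12 transformer blocks (the factor is inferred from the ratio of consecutive values).

```python
import math

# Values taken from the notebook and its training log.
train_batch_size = 32  # setTrainBatchSize(32)
update_freq = 2        # "Update frequent = 2"
num_train = 253        # "Number of training examples = 253"
num_test = 147         # "Accuracy of the network on the 147 test images"

# Assumption: logged "Batch size" = per-step batch size * update frequency.
effective_batch = train_batch_size * update_freq
# Assumption: incomplete batches are dropped, hence floor division.
iterations_per_epoch = num_train // train_batch_size  # compare "Epoch: [0]  [0/7]"
updates_per_epoch = num_train // effective_batch      # compare "per epoch = 3"

# Assumption: per-layer lr scale = decay ** (num_layers + 1 - layer_index).
layer_decay, num_layers = 0.75, 12
lr_scales = [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]

print(effective_batch, iterations_per_epoch, updates_per_epoch)  # 64 7 3
print(num_train + num_test)                                      # 400
print(len(lr_scales), lr_scales[-1])                             # 14 1.0
print(math.isclose(lr_scales[0], 0.023757264018058777))          # True
```

Read this way, the 400-row subset was split into 253 training and 147 validation images because `validate=True` was passed to `fit`, and the earliest layers train with roughly 2% of the base learning rate.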


In [14]:
fitted_classifier.save("./fitted_model_best")

                                                                                