# Example of Fine-Tuning ImageToTextV2

This notebook demonstrates a pipeline for fine-tuning __ImageToTextV2__ text recognition.

## Install spark-ocr Python package
Need to specify:
- secret
- license
- aws credentials

In [1]:
# Credentials for the private package index (secret), the OCR license and AWS.
# Fill these in locally before running; do not commit real values.
secret = ""
license = ""
AWS_ACCESS_KEY_ID = ""
AWS_SECRET_ACCESS_KEY = ""

# The part of the secret before the first "-" is the spark-ocr version to install.
version = secret.split("-")[0]
# Scala 2.12 build directory: consistent with the jar_path given to start() and
# with the "+spark32" package variant installed below.
spark_ocr_jar_path = "../../../target/scala-2.12/"

In [2]:
import os

# Set before Spark is started so that child processes inherit the flag.
# NOTE(review): presumably a workaround for fork-safety crashes on macOS — confirm.
os.environ.update({'OBJC_DISABLE_INITIALIZE_FORK_SAFETY': 'YES'})

In [None]:
# Install the notebook's Python dependencies into the kernel's environment.
# spark-ocr is fetched from the private John Snow Labs index; IPython expands
# $version and $secret from the variables defined in the configuration cell,
# so that cell must be run (with a real secret) before this one.
%pip install pillow==9.2.0
%pip install trdg
%pip install spark-nlp==4.0.0
%pip install torch
%pip install transformers==4.16.2
%pip install spark-ocr==$version+spark32 --extra-index-url=https://pypi.johnsnowlabs.com/$secret --upgrade

## Initialization of Spark session

In [3]:
from sparkocr import start

# Gather only the settings that were actually filled in, then export them to the
# process environment in a single step before the session is created.
_env_overrides = {}
if license:
    _env_overrides['JSL_OCR_LICENSE'] = license
if AWS_ACCESS_KEY_ID:
    # NOTE(review): the access key is exported as AWS_ACCESS_KEY (not
    # AWS_ACCESS_KEY_ID) — confirm this is the variable name Spark OCR reads.
    _env_overrides['AWS_ACCESS_KEY'] = AWS_ACCESS_KEY_ID
    _env_overrides['AWS_SECRET_ACCESS_KEY'] = AWS_SECRET_ACCESS_KEY
os.environ.update(_env_overrides)

# Start the Spark session using the locally built jars.
spark = start(jar_path="../../../target/scala-2.12")

# Last expression of the cell: display the session summary.
spark

Spark version: 3.2.1
Spark NLP version: 4.1.0
Spark NLP for Healthcare version: 4.0.0
Spark OCR version: 4.1.0rc3



# Data Generation

In [7]:
import pkg_resources, io
import pyspark.sql.functions as f
from pyspark.ml import PipelineModel
from sparkocr.transformers import *
from sparkocr.enums import *
from sparkocr.utils import display_images

from trdg.generators import GeneratorFromStrings
from trdg.string_generator import create_strings_from_wikipedia


In [9]:
# Pull sample text from Wikipedia and keep at most the first 50 characters of each string.
wiki_strings = [sentence[:50] for sentence in create_strings_from_wikipedia(10, 2, 'en')]

# One hand-written line containing special characters goes first, followed by
# the Wikipedia strings; the generator is configured to yield `count` samples.
generator = GeneratorFromStrings(
    ['2 μL of pGEM®-Tplasmid DNA (0.5 μg/ well) was'] + wiki_strings,
    background_type=1,
    word_split=True,
    count=2,
)


def _png_bytes(image):
    """Return `image` encoded as PNG bytes, using the image's own ``save`` method."""
    with io.BytesIO() as buffer:
        image.save(buffer, format='PNG')
        return buffer.getvalue()


# Each generated (image, text) pair becomes one [png_bytes, text] row.
data = [[_png_bytes(image), label] for image, label in generator]

df = spark.createDataFrame(data=data, schema=["content", "text"])

# Fine Tune the Model

In [11]:
from pyspark.sql.functions import lit

# Convert the generated image bytes into an "image" column for the OCR stage.
# A constant "path" column is added first because the images exist only in memory.
# NOTE(review): presumably BinaryToImage reads the "content" column and expects
# a "path" column — confirm against the Spark OCR documentation.
bin_to_image = BinaryToImage()
bin_to_image.setOutputCol("image")
# Bind the result to a new name so that re-running this cell starts again from
# the untouched `df` instead of stacking a second conversion on top of it.
train_df = bin_to_image.transform(df.withColumn("path", lit("memory")))

# Start from the pretrained printed-text model and fine-tune it for one epoch.
ocr = ImageToTextV2().pretrained("ocr_base_printed", "en", "clinical/ocr") \
    .setInputCols(["image"]) \
    .setOutputCol("text") \
    .setUsePandasUdf(False) \
    .setNumTrainEpochs(1)

# NOTE(review): the value returned by fit() is discarded and transform() is
# called on `ocr` itself; if fit() returns a separate fitted model, transform
# should be called on that object instead — confirm.
# NOTE(review): the recorded run of this cell fails inside training with
# "TypeError: type object argument after ** must be a mapping, not Row";
# that failure is not addressed here.
ocr.fit(train_df)

# Run recognition on the dataframe prepared above; the previous code passed
# `image_text_lines_df`, a name that is never defined in this notebook.
result_df = ocr.transform(train_df)
result_df.select('text').show()
result = result_df.select('text').collect()

ocr_base_printed download started this may take some time.
Approximate size to download 743.7 MB


12:18:46, INFO Training model.
***** Running training *****
  Num examples = 1
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 1


TypeError: type object argument after ** must be a mapping, not Row