<a href="https://colab.research.google.com/github/arthurflor23/text-correction/blob/master/src/tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://github.com/arthurflor23/text-correction/blob/master/doc/image/header.png?raw=true">

# Text Correction using TensorFlow 2.0

This tutorial shows how to use the [Text Correction](https://github.com/arthurflor23/text-correction) project in Google Colab.



## 1 Localhost Environment

First, make sure the project and the dataset folders are in your Google Drive. If your files are already structured in the cloud, skip this step.

### 1.1 Datasets

The datasets that you can use:

a. [BEA2019](https://www.cl.cam.ac.uk/research/nl/bea2019st/)

b. [Bentham](http://transcriptorium.eu/datasets/bentham-collection/)

c. [CoNLL13](https://www.comp.nus.edu.sg/~nlp/conll13st.html)

d. [CoNLL14](https://www.comp.nus.edu.sg/~nlp/conll14st.html)

e. [Google](https://ai.google/research/pubs/pub41880)

f. [IAM](http://www.fki.inf.unibe.ch/databases/iam-handwriting-database)

g. [Rimes](http://www.a2ialab.com/doku.php?id=rimes_database:start)

h. [Washington](http://www.fki.inf.unibe.ch/databases/iam-historical-document-database/washington-database)

### 1.2 Raw folder

On localhost, download the project code from GitHub and extract the chosen dataset (or all of them, if you prefer) into the **raw** folder. Don't change the structure of any dataset, since the scripts rely on each dataset's **original structure**. Your project directory will look like this:

```
.
├── raw
│   ├── bea2019
│   │   ├── json
│   │   ├── json_to_m2.py
│   │   ├── licence.wi.txt
│   │   ├── license.locness.txt
│   │   ├── m2
│   │   └── readme.txt
│   ├── bentham
│   │   ├── BenthamDatasetR0-GT
│   │   └── BenthamDatasetR0-Images
│   ├── conll13
│   │   ├── m2scorer
│   │   ├── original
│   │   ├── README
│   │   ├── revised
│   │   └── scripts
│   ├── conll14
│   │   ├── alt
│   │   ├── noalt
│   │   ├── README
│   │   └── scripts
│   ├── google
│   │   ├── europarl-v6.cs
│   │   ├── europarl-v6.de
│   │   ├── europarl-v6.en
│   │   ├── europarl-v6.es
│   │   ├── europarl-v6.fr
│   │   ├── news.2007.cs.shuffled
│   │   ├── news.2007.de.shuffled
│   │   ├── news.2007.en.shuffled
│   │   ├── news.2007.es.shuffled
│   │   ├── news.2007.fr.shuffled
│   │   ├── news.2008.cs.shuffled
│   │   ├── news.2008.de.shuffled
│   │   ├── news.2008.en.shuffled
│   │   ├── news.2008.es.shuffled
│   │   ├── news.2008.fr.shuffled
│   │   ├── news.2009.cs.shuffled
│   │   ├── news.2009.de.shuffled
│   │   ├── news.2009.en.shuffled
│   │   ├── news.2009.es.shuffled
│   │   ├── news.2009.fr.shuffled
│   │   ├── news.2010.cs.shuffled
│   │   ├── news.2010.de.shuffled
│   │   ├── news.2010.en.shuffled
│   │   ├── news.2010.es.shuffled
│   │   ├── news.2010.fr.shuffled
│   │   ├── news.2011.cs.shuffled
│   │   ├── news.2011.de.shuffled
│   │   ├── news.2011.en.shuffled
│   │   ├── news.2011.es.shuffled
│   │   ├── news.2011.fr.shuffled
│   │   ├── news-commentary-v6.cs
│   │   ├── news-commentary-v6.de
│   │   ├── news-commentary-v6.en
│   │   ├── news-commentary-v6.es
│   │   └── news-commentary-v6.fr
│   ├── iam
│   │   ├── ascii
│   │   ├── forms
│   │   ├── largeWriterIndependentTextLineRecognitionTask
│   │   ├── lines
│   │   └── xml
│   ├── rimes
│   │   ├── eval_2011
│   │   ├── eval_2011_annotated.xml
│   │   ├── training_2011
│   │   └── training_2011.xml
│   ├── saintgall
│   │   ├── data
│   │   ├── ground_truth
│   │   ├── README.txt
│   │   └── sets
│   └── washington
│       ├── data
│       ├── ground_truth
│       ├── README.txt
│       └── sets
└── src
    ├── data
    │   ├── evaluation.py
    │   ├── generator.py
    │   ├── __init__.py
    │   ├── m2.py
    │   └── preproc.py
    ├── main.py
    ├── tool
    │   ├── __init__.py
    │   ├── symspell.py 
    │   └── transformer.py
    ├── transform
    │   ├── bea2019.py
    │   ├── bentham.py
    │   ├── conll13.py
    │   ├── conll14.py
    │   ├── google.py
    │   ├── iam.py
    │   ├── __init__.py
    │   ├── rimes.py
    │   └── washington.py
    └── tutorial.ipynb

```

After that, create a virtual environment and install the dependencies with Python 3 and pip:

> ```python -m venv .venv && source .venv/bin/activate```

> ```pip install -r requirements.txt```

### 1.3 Dataset folders

Now run the *transform* function from **main.py**. To do this, execute the following from the **src** folder:

> ```python main.py --dataset=<DATASET_NAME> --transform```

To work with **all** datasets, run:

> ```python main.py --transform```
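
The two commands above differ only in whether `--dataset` is given. As a rough illustration, flags like these could be parsed with `argparse` as sketched below. This is not the project's actual **main.py**: the function name `build_parser` and the default value `"all"` are assumptions made for this example.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the two flags used in the commands above."""
    parser = argparse.ArgumentParser(description="Text correction pipeline (sketch)")
    # Assumption: leaving out --dataset selects every dataset ("all").
    parser.add_argument("--dataset", type=str, default="all")
    # Boolean switch: True when the flag is present, False otherwise.
    parser.add_argument("--transform", action="store_true", default=False)
    return parser


# Parse an explicit argument list instead of sys.argv, so this also runs in a notebook.
args = build_parser().parse_args(["--dataset=iam", "--transform"])
print(args.dataset, args.transform)
```

Passing an empty list to `parse_args` would leave `dataset` at its assumed default and `transform` set to `False`.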

Your data will be preprocessed, encoded, and saved in the **data** folder. Your project directory will now look like this:


```
.
├── data
│   ├── all.txt
│   ├── bea2019.txt
│   ├── bentham.txt
│   ├── conll13.txt
│   ├── conll14.txt
│   ├── google.txt
│   ├── iam.txt
│   ├── rimes.txt
│   └── washington.txt
├── raw
│   ├── bea2019
│   │   ├── json
│   │   ├── json_to_m2.py
│   │   ├── licence.wi.txt
│   │   ├── license.locness.txt
│   │   ├── m2
│   │   └── readme.txt
│   ├── bentham
│   │   ├── BenthamDatasetR0-GT
│   │   └── BenthamDatasetR0-Images
│   ├── conll13
│   │   ├── m2scorer
│   │   ├── original
│   │   ├── README
│   │   ├── revised
│   │   └── scripts
│   ├── conll14
│   │   ├── alt
│   │   ├── noalt
│   │   ├── README
│   │   └── scripts
│   ├── google
│   │   ├── europarl-v6.cs
│   │   ├── europarl-v6.de
│   │   ├── europarl-v6.en
│   │   ├── europarl-v6.es
│   │   ├── europarl-v6.fr
│   │   ├── news.2007.cs.shuffled
│   │   ├── news.2007.de.shuffled
│   │   ├── news.2007.en.shuffled
│   │   ├── news.2007.es.shuffled
│   │   ├── news.2007.fr.shuffled
│   │   ├── news.2008.cs.shuffled
│   │   ├── news.2008.de.shuffled
│   │   ├── news.2008.en.shuffled
│   │   ├── news.2008.es.shuffled
│   │   ├── news.2008.fr.shuffled
│   │   ├── news.2009.cs.shuffled
│   │   ├── news.2009.de.shuffled
│   │   ├── news.2009.en.shuffled
│   │   ├── news.2009.es.shuffled
│   │   ├── news.2009.fr.shuffled
│   │   ├── news.2010.cs.shuffled
│   │   ├── news.2010.de.shuffled
│   │   ├── news.2010.en.shuffled
│   │   ├── news.2010.es.shuffled
│   │   ├── news.2010.fr.shuffled
│   │   ├── news.2011.cs.shuffled
│   │   ├── news.2011.de.shuffled
│   │   ├── news.2011.en.shuffled
│   │   ├── news.2011.es.shuffled
│   │   ├── news.2011.fr.shuffled
│   │   ├── news-commentary-v6.cs
│   │   ├── news-commentary-v6.de
│   │   ├── news-commentary-v6.en
│   │   ├── news-commentary-v6.es
│   │   └── news-commentary-v6.fr
│   ├── iam
│   │   ├── ascii
│   │   ├── forms
│   │   ├── largeWriterIndependentTextLineRecognitionTask
│   │   ├── lines
│   │   └── xml
│   ├── rimes
│   │   ├── eval_2011
│   │   ├── eval_2011_annotated.xml
│   │   ├── training_2011
│   │   └── training_2011.xml
│   ├── saintgall
│   │   ├── data
│   │   ├── ground_truth
│   │   ├── README.txt
│   │   └── sets
│   └── washington
│       ├── data
│       ├── ground_truth
│       ├── README.txt
│       └── sets
└── src
    ├── data
    │   ├── evaluation.py
    │   ├── generator.py
    │   ├── __init__.py
    │   ├── m2.py
    │   └── preproc.py
    ├── main.py
    ├── tool
    │   ├── __init__.py
    │   ├── symspell.py 
    │   └── transformer.py
    ├── transform
    │   ├── bea2019.py
    │   ├── bentham.py
    │   ├── conll13.py
    │   ├── conll14.py
    │   ├── google.py
    │   ├── iam.py
    │   ├── __init__.py
    │   ├── rimes.py
    │   └── washington.py
    └── tutorial.ipynb

```

Then upload the **data** and **src** folders to the same directory in your Google Drive.

## 2 Google Drive Environment


### 2.1 TensorFlow 2.0

Make sure the Jupyter notebook is running in GPU mode. If possible, try to get a **Tesla T4** instead of a Tesla K80, since it is faster.

In [0]:
!nvidia-smi

Now, we'll install TensorFlow 2.0 with GPU support.

In [0]:
!pip install -q gast==0.2.2 tensorflow-gpu==2.0.0-beta1

In [0]:
import tensorflow as tf

device_name = tf.test.gpu_device_name()

if device_name != "/device:GPU:0":
    raise SystemError("GPU device not found")

print(f"Found GPU at: {device_name}")

### 2.2 Google Drive

Mount your Google Drive partition.

**Note:** *"Colab Notebooks/text-correction/src/"* is the directory where you put the project folders, specifically the **src** folder.

In [0]:
from google.colab import drive

drive.mount("./gdrive", force_remount=True)

%cd "./gdrive/My Drive/Colab Notebooks/text-correction/src/"
!ls -l

After mounting, you can see the list of files in the project folder.
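
Since the rest of the notebook expects the **data** and **src** folders to sit side by side, a quick sanity check can save a confusing error later. The helper below is a minimal sketch; `missing_project_dirs` is a hypothetical name and is not part of the project.

```python
from pathlib import Path


def missing_project_dirs(root, expected=("data", "src")):
    """Return the expected sub-folders that are absent under `root`."""
    root = Path(root)
    return [name for name in expected if not (root / name).is_dir()]


# From inside the src folder, the project root is one level up.
missing = missing_project_dirs("..")
if missing:
    print("Missing folders:", ", ".join(missing))
```

An empty list means both folders were found next to each other.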

## 3 Set Python Classes

### 3.1 Environment

First, let's define our environment variables.

Set the main configuration parameters, such as dataset, method, number of epochs and batch size. This keeps the notebook compatible with **main.py**:

* **dataset**: "all", "bea2019", "bentham", "conll13", "conll14", "google", "iam", "rimes", "washington"

* **mode**: the method to use: neural "seq2seq" or "transformer", or statistical "symspell" (localhost only)

* **epochs**: number of epochs

* **batch_size**: size of each batch

In [0]:
import os

# define parameters
dataset = "all"
mode = "seq2seq"
epochs = 1000
batch_size = 64

# define paths
data_path = os.path.join("..", "data")
m2_src = os.path.join(data_path, f"{dataset}.txt")

output_path = os.path.join("..", "output", dataset, mode)
os.makedirs(output_path, exist_ok=True)

# define the maximum number of chars per line and the list of valid chars
max_text_length = 128
charset_base = "".join([chr(i) for i in range(32, 127)])
charset_special = "".join([chr(i) for i in range(192, 256)])

print("output:", output_path)
print("charset:", charset_base + charset_special)
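
To make the two character ranges concrete, the sketch below rebuilds the same charset and checks a sentence against it. The helpers `unsupported_chars` and `fits` are hypothetical names introduced only for this illustration; how the project itself handles out-of-charset characters is not shown here.

```python
# Same charset construction as in the cell above.
charset_base = "".join(chr(i) for i in range(32, 127))      # printable ASCII
charset_special = "".join(chr(i) for i in range(192, 256))  # Latin-1 accented letters, plus × and ÷
charset = charset_base + charset_special


def unsupported_chars(text, charset=charset):
    """Return the sorted list of characters in `text` that are not in `charset`."""
    return sorted(set(text) - set(charset))


def fits(text, max_text_length=128, charset=charset):
    """Hypothetical check: the text is short enough and uses only valid characters."""
    return len(text) <= max_text_length and not unsupported_chars(text, charset)


print(len(charset_base), len(charset_special))  # 95 64
print(unsupported_chars("café costs 3€"))       # ['€']
```

Accented letters such as "é" fall inside the special range, while symbols outside Latin-1, such as "€", do not.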

### 3.2 DataGenerator Class

The **DataGenerator()** class is responsible for:

* Loading the dataset partitions (train, valid, test);

* Managing batches for the training/validation/test process.

In [0]:
from data.generator import DataGenerator

dtgen = DataGenerator(m2_src=m2_src,
                      batch_size=batch_size,
                      charset=(charset_base + charset_special),
                      max_text_length=max_text_length)

print(f"Train sentences: {dtgen.total_train}")
print(f"Validation sentences: {dtgen.total_valid}")
print(f"Test sentences: {dtgen.total_test}")
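
For readers new to generator-based training, the sketch below shows the general idea of serving batches on demand. It is not the project's `DataGenerator`: the names `batch_generator` and `steps_per_epoch` are assumptions, and the real class may encode, shuffle, or add noise to the data in ways not shown here.

```python
import math


def batch_generator(pairs, batch_size):
    """Yield (inputs, targets) batches forever, as Keras-style generators do."""
    while True:
        for start in range(0, len(pairs), batch_size):
            chunk = pairs[start:start + batch_size]
            inputs = [src for src, _ in chunk]
            targets = [tgt for _, tgt in chunk]
            yield inputs, targets


def steps_per_epoch(total, batch_size):
    """Number of batches needed to see every item once."""
    return math.ceil(total / batch_size)


# Toy (noisy, correct) sentence pairs, for illustration only.
pairs = [("teh cat", "the cat"), ("a dgo", "a dog"), ("helo", "hello")]
gen = batch_generator(pairs, batch_size=2)
print(next(gen))
print(steps_per_epoch(len(pairs), 2))
```

Because the generator never stops on its own, the training loop needs a step count per epoch to know when one pass over the data is complete.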

### 3.3 Neural Network Model

In this step, the model is created (or loaded from a checkpoint) and the default callbacks are set up.

In [0]:
import time
from data import preproc as pp, evaluation
from tool.seq2seq import Seq2SeqAttention
from tool.transformer import Transformer

if mode == "transformer":
    # disable one-hot encoding (used by seq2seq) to use the transformer model
    dtgen.one_hot_process(False)
    model = Transformer(num_layers=2, units=512, d_model=256, num_heads=4,
                        dropout=0.1, tokenizer=dtgen.tokenizer)

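elif mode == "seq2seq":
    # increase the amount of noise to make the learning process incremental
    dtgen.increase_noise(0.001, from_up=0)
    model = Seq2SeqAttention(units=128, dropout=0.1, tokenizer=dtgen.tokenizer)

else:
    # "symspell" is localhost only (see 3.1), so no model is built for it here
    raise ValueError(f"Unsupported mode for this notebook: {mode}")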

# set `learning_rate` to a custom value, or to `None` to use the default schedule function
model.compile(learning_rate=0.001)

# save network summary
model.summary(output_path, "summary.txt")

# get the default callbacks list and load the checkpoint weights file (HDF5) if it exists
checkpoint = "checkpoint_weights.hdf5"
callbacks = model.get_callbacks(logdir=output_path, hdf5=checkpoint, verbose=1)

model.load_checkpoint(target=os.path.join(output_path, checkpoint))

## 4 Tensorboard

To make it easier to visualize the model's training, you can launch TensorBoard.

**Note**: All data is saved in the output folder.

In [0]:
%load_ext tensorboard
%tensorboard --reload_interval=2000 --logdir={output_path}

## 5 Training

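Training uses *fit_generator()*, which consumes batches from a generator so that the whole dataset does not have to fit in memory at once. After training, a summary (number of epochs, timing, and the loss and accuracy at the best epoch) is saved.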

In [0]:
# to calculate total and average time per epoch
start_time = time.time()
h = model.fit_generator(generator=dtgen.next_train_batch(),
                        epochs=epochs,
                        steps_per_epoch=dtgen.train_steps,
                        validation_data=dtgen.next_valid_batch(),
                        validation_steps=dtgen.valid_steps,
                        callbacks=callbacks,
                        shuffle=True,
                        verbose=1)
total_time = time.time() - start_time

loss = h.history['loss']
accuracy = h.history['accuracy']

val_loss = h.history['val_loss']
val_accuracy = h.history['val_accuracy']

time_epoch = (total_time / len(accuracy))
total_item = (dtgen.total_train + dtgen.total_valid)
best_epoch_index = val_accuracy.index(max(val_accuracy))

train_corpus = "\n".join([
    f"Total train sentences:      {dtgen.total_train}",
    f"Total validation sentences: {dtgen.total_valid}",
    f"Batch:                      {dtgen.batch_size}\n",
    f"Total epochs:               {len(accuracy)}",
    f"Total time:                 {total_time:.8f} sec",
    f"Time per epoch:             {time_epoch:.8f} sec",
    f"Time per item:              {(time_epoch / total_item):.8f} sec\n",
    f"Best epoch                  {best_epoch_index + 1}",
    f"Training loss:              {loss[best_epoch_index]:.8f}",
    f"Training accuracy:          {accuracy[best_epoch_index]:.8f}\n",
    f"Validation loss:            {val_loss[best_epoch_index]:.8f}",
    f"Validation accuracy:        {val_accuracy[best_epoch_index]:.8f}"
])

with open(os.path.join(output_path, "train.txt"), "w") as lg:
    lg.write(train_corpus)
    print(train_corpus)
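
The report above picks the epoch with the highest validation accuracy and derives the timings by simple division. The sketch below repeats that arithmetic on made-up numbers; `summarize_history` is a hypothetical helper and the values are illustrative only, not real results.

```python
def summarize_history(history, total_time, total_items):
    """Pick the epoch with the highest validation accuracy and derive timings."""
    val_accuracy = history["val_accuracy"]
    best = val_accuracy.index(max(val_accuracy))  # first occurrence wins on ties
    epochs_run = len(val_accuracy)
    time_per_epoch = total_time / epochs_run
    return {
        "best_epoch": best + 1,  # 1-based, as in the report
        "val_loss": history["val_loss"][best],
        "val_accuracy": val_accuracy[best],
        "time_per_epoch": time_per_epoch,
        "time_per_item": time_per_epoch / total_items,
    }


# Made-up numbers, for illustration only.
history = {"val_loss": [0.9, 0.5, 0.6], "val_accuracy": [0.70, 0.85, 0.80]}
summary = summarize_history(history, total_time=30.0, total_items=100)
print(summary["best_epoch"], summary["time_per_epoch"], summary["time_per_item"])
```

Note that the number of epochs is taken from the length of the history rather than from the `epochs` parameter, so the timings stay correct if a callback stops training early.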

## 6 Predict and Evaluate

Since the goal is to correct text, the metrics (CER and WER) are calculated both before and after the correction.
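
As a reminder of what these metrics measure, here is a minimal sketch of character and word error rates based on Levenshtein edit distance. It assumes the common definition (edits divided by reference length); the project's `evaluation.ocr_metrics` may normalize or preprocess the text differently, and the function names below are chosen only for this example.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (strings or lists)."""
    previous = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        current = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            current.append(min(previous[j] + 1,          # deletion
                               current[j - 1] + 1,       # insertion
                               previous[j - 1] + cost))  # substitution
        previous = current
    return previous[-1]


def cer(reference, hypothesis):
    """Character error rate: character edits divided by reference length."""
    return edit_distance(reference, hypothesis) / max(len(reference), 1)


def wer(reference, hypothesis):
    """Word error rate: word edits divided by the number of reference words."""
    ref_words, hyp_words = reference.split(), hypothesis.split()
    return edit_distance(ref_words, hyp_words) / max(len(ref_words), 1)


ground_truth = "the quick brown fox"
noisy = "teh quick brown fox"
corrected = "the quick brown fox"
print("before:", cer(ground_truth, noisy), wer(ground_truth, noisy))
print("after: ", cer(ground_truth, corrected), wer(ground_truth, corrected))
```

Comparing the "before" and "after" values shows how much of the error the correction step removed.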

The prediction step also uses a generator, through *predict_generator()*:

In [0]:
start_time = time.time()
predicts = model.predict_generator(generator=dtgen.next_test_batch(),
                                   steps=dtgen.test_steps,
                                   use_multiprocessing=True,
                                   verbose=1)
total_time = time.time() - start_time

# calculate metrics (before and after)
old_metric = evaluation.ocr_metrics(dtgen.dataset["test"]["dt"], dtgen.dataset["test"]["gt"])
new_metric = evaluation.ocr_metrics(predicts, dtgen.dataset["test"]["gt"])

# generate report
pred_corpus, eval_corpus = evaluation.report(dtgen, predicts, [old_metric, new_metric], total_time)

with open(os.path.join(output_path, "predict.txt"), "w") as lg:
    lg.write("\n".join(pred_corpus))
    print("\n".join(pred_corpus[:30]))

with open(os.path.join(output_path, "evaluate.txt"), "w") as lg:
    lg.write(eval_corpus)
    print(eval_corpus)