
TextBoxGAN

Generates text boxes from input words with a Generative Adversarial Network.

Video: Generating the word "Generate" at different training steps: https://youtu.be/YdicGxqRWOY


Figure 1: Different examples of generating "Words with the same style" using our model


Requirements

For training or using the projector:

  • 1 or more CUDA-compatible GPUs (making it work on CPU is in progress)

PS: Inference and the projector now work on CPU.

Install with docker

Build the image:

docker build -t textboxgan .

Run the docker image:

docker run --gpus all -it -v `pwd`:/TextBoxGAN textboxgan bash

Install with poetry

poetry install

Download datasets / models

Download the models

The following models are available in this Google Drive:

  • trained model: pretrained model (cf. the Results section for more details on the model). Place this directory in the experiments directory. To use it, replace EXPERIMENT_NAME = None with EXPERIMENT_NAME = "trained model", and ensure cfg.resume_step = 225000 in the config file.
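
For example, the relevant lines of config.py would then look like this (a minimal sketch reflecting only the two settings named above; the rest of the file is untouched):

EXPERIMENT_NAME = "trained model"  # was: EXPERIMENT_NAME = None
cfg.resume_step = 225000           # training step of the pretrained checkpoint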

Required for training and running the projector:

  • aster_weights: weights of the ASTER OCR converted to tf2. Place this directory at the root of the project.

Required for running the projector:

  • perceptual_weights: weights of the perceptual loss, converted from PyTorch using moono's repo. Place this directory in the projector directory.

Download and make datasets

make download-and-make-datasets

All the following commands should be run inside the Docker container.

Train

Specify all the configs in config.py.

poetry run python train.py

Generate chosen words

Generate "Hello" and "World" 20 times:

poetry run python infer.py --infer_type "chosen_words" --words_to_generate "Hello" "World" --num_inferences 20 --output_dir "<output directory>"

Infer the test set

Get an average over 50 runs (since random vectors are used, the test set result is not constant):

poetry run python infer.py --infer_type "test_set" --num_test_set_run 50

Run the projector

poetry run python -m projector.projector --target_image_path "<path/to/image>" --text_on_the_image "<text on image>" --output_dir "<output directory>"

Launch tensorboard

Display the logs of one or several experiments (e.g. xp1 and xp2):

make tensorboard xps="xp1 xp2"

If your experiments are stored on a VM (to which you are connected via SSH), add the following alias to your .bashrc or .zshrc:

function vm_tensorboard() { ssh -t -L 6006:localhost:6006 <vm address> "cd <path/to/your/repo> && make tensorboard xps=\"$1\""; }

You can now access the TensorBoard UI locally:

vm_tensorboard "xp1 xp2"

Access the UI

TensorBoard's default port is 6006, so the UI can be accessed at http://localhost:6006/.

Technical Documentation


Figure 2: Network architecture. The green and brown arrows represent the backpropagation of the OCR and GAN losses, respectively.


Figure 3: Word Encoder

All hyperparameters stated below can be configured in config.py.

Inputs:

  • Words whose length is less than or equal to max_char_number.
  • A normally distributed random vector of shape (batch_size, z_dim).

Word Encoder (code): Computes a vector representation for each character of the input word using an embedding followed by a dense layer. The tensor containing all the encoded characters is then reshaped, while preserving the order of the characters. This tensor is then zero-padded, since the generator requires a fixed-size input regardless of the word's length. This process is described in Figure 3. In the code, the integer sequence representing the word is zero-padded in the data loader, and the corresponding embedding is filled with 0s.
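
A minimal sketch of this encoding step, assuming hypothetical layer names and shapes (the real implementation lives in the repository's word encoder module):

import tensorflow as tf

class WordEncoder(tf.keras.layers.Layer):
    # Illustrative only: embeds each character id, projects it with a dense
    # layer, and keeps padded positions exactly 0, as described above.
    def __init__(self, vocab_size, char_embedding_dim, char_encoding_dim):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, char_embedding_dim)
        self.dense = tf.keras.layers.Dense(char_encoding_dim)

    def call(self, padded_char_ids):
        # padded_char_ids: (batch, max_char_number), 0-padded by the data loader
        x = self.embedding(padded_char_ids)   # (batch, max_chars, embedding_dim)
        x = self.dense(x)                     # (batch, max_chars, encoding_dim)
        pad_mask = tf.cast(padded_char_ids > 0, x.dtype)[..., tf.newaxis]
        return x * pad_mask                   # later reshaped for the generator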

Latent Encoder (code) *(Refer to the StyleGAN2 paper for more information)*: Generates the style of the image by encoding the noise vector through n_mapping dense layers, leading to a tensor of size (batch_size, style_dim), where style_dim is usually equal to z_dim. The resulting tensor is duplicated n_style times, where n_style is equal to the number of layers in the Synthesis Network. During training, a second style vector is generated and mixed with the first one to broaden the distribution of the generated images. During evaluation, the style vector can be truncated to vary its intensity.
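
A sketch of the mapping and style-mixing logic under the same hyperparameter names (n_mapping, style_dim, n_style); the layer details are assumptions, not the repository's exact code:

import tensorflow as tf

def build_latent_encoder(style_dim, n_mapping):
    # n_mapping fully connected layers (StyleGAN2 uses leaky ReLU here)
    return tf.keras.Sequential(
        [tf.keras.layers.Dense(style_dim, activation=tf.nn.leaky_relu)
         for _ in range(n_mapping)]
    )

def broadcast_style(w, n_style):
    # Duplicate the style vector once per synthesis layer
    return tf.tile(w[:, tf.newaxis, :], [1, n_style, 1])  # (batch, n_style, style_dim)

def mix_styles(w1, w2, n_style):
    # Style mixing: layers before a random crossover take w1, the rest take w2
    cutoff = tf.random.uniform([], 1, n_style, dtype=tf.int32)
    layer_idx = tf.range(n_style)[tf.newaxis, :, tf.newaxis]
    return tf.where(layer_idx < cutoff, w1, w2)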

Synthesis Network (code) *(Refer to the StyleGAN2 paper for more information)*: Synthesises a text box from the encoded word tensor and a style vector. At each layer, the style vectors are applied directly to the convolution kernels instead of the image, for efficiency purposes. An RGB image is generated at each upscaling and they are added up, allowing the network to train better globally (and not only train the final layers).
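
The RGB accumulation can be sketched as follows; synthesis_blocks and to_rgb_layers are hypothetical callables standing in for the repository's layers:

import tensorflow as tf

def upscale2x(img):
    # Nearest-neighbour upscaling (the real implementation may use filtering)
    h, w = img.shape[1], img.shape[2]
    return tf.image.resize(img, [h * 2, w * 2], method="nearest")

def synthesize(x, style, synthesis_blocks, to_rgb_layers):
    rgb = None
    for block, to_rgb in zip(synthesis_blocks, to_rgb_layers):
        x = block(x, style)      # upscaling + style-modulated convolutions
        y = to_rgb(x, style)     # 3-channel image at the current resolution
        rgb = y if rgb is None else upscale2x(rgb) + y
    return rgb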

Handling words of different sizes: The text box generated by the synthesis network has a fixed width, regardless of the number N of characters in the input word. To handle words of different sizes, and hence images of different widths, a mask is applied on the text box before feeding it to the discriminator. It is assumed that each character has a width char_width = image_width / max_char_number. Thus, the mask zeroes all elements whose horizontal position is beyond N * char_width. This both forces the synthesis network not to distort characters and prevents the discriminator from using residual noise as a pattern to distinguish fake from real boxes. Similarly, real text boxes are resized and zero-padded according to the length of their label.
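
A minimal sketch of this mask, with illustrative names:

import tensorflow as tf

def mask_text_box(images, word_lengths, max_char_number):
    # images:       (batch, height, image_width, channels)
    # word_lengths: (batch,) number of characters N in each word
    image_width = tf.shape(images)[2]
    char_width = image_width // max_char_number
    valid_width = word_lengths * char_width                 # keep N * char_width columns
    cols = tf.range(image_width)[tf.newaxis, :]             # (1, image_width)
    mask = tf.cast(cols < valid_width[:, tf.newaxis], images.dtype)
    return images * mask[:, tf.newaxis, :, tf.newaxis]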

Discriminator (code) *(Refer to the StyleGAN2 paper for more information)*: Classifies the text boxes as real or fake.

OCR (ASTER) (code) *(Refer to the ASTER paper for more information)*: Recognizes the characters present in a text box with a combination of a CNN and an attention RNN. A pre-trained model is utilised to train our model, and its weights are frozen. Since the RNN is bidirectional, it outputs two predictions: one with the word written in the right direction, and the other one backward. Combining both predictions allows, in theory, reaching a more accurate prediction (though not with our pre-trained model). The prediction consists of a 2D array of logits, representing the probability of each possible character, for each character identified in the text box.

OCR Loss (code): Two losses were experimented with (a minimal sketch follows the list):

  • Softmax crossentropy (SCE): applied between the OCR's output logits of the generated text box and the label, i.e. the input word.
  • Mean squared error (MSE): applied between the OCR's output logits of the generated text box and those of a real text box, with the same label. Hence, our problem is modeled as a regression rather than a classification.
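
Both losses can be sketched as follows; the logit shapes are assumptions, chosen to match the description above:

import tensorflow as tf

# logits_fake: OCR logits for the generated box, (batch, max_chars, n_classes)
# labels:      integer character ids of the input word, (batch, max_chars)
# logits_real: OCR logits for a real box with the same label

def ocr_sce_loss(logits_fake, labels):
    # Softmax crossentropy between the OCR's predictions and the input word
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits_fake)
    return tf.reduce_mean(ce)

def ocr_mse_loss(logits_fake, logits_real):
    # Regression: match the OCR logits of a real box with the same label
    return tf.reduce_mean(tf.square(logits_fake - logits_real))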

GAN Loss / Regularization:

  • SoftPlus (code): applied on the discriminator prediction, to train both the discriminator and the generator.
  • Path Length Regularization (code): encourages the derivative of the generated image with respect to the style vector generated by the Latent Encoder to remain constant. As a result, a change in the style vector leads to a proportional perceptual change in the output image, leading to a more consistent behavior of the generator. More details here.
  • R1 Regularization (code): stabilizes the overall training through the application of a gradient penalty on the Discriminator, forcing it to remain within the Nash Equilibrium. More details here.

The GAN loss is not propagated through the Word Encoder: doing so would encourage the latter to output a fixed value regardless of the word, since that eases the task of the Synthesis Network.
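
For reference, a simplified sketch of the softplus GAN losses and the R1 penalty (the repository's actual implementation, e.g. lazy regularization, is more involved):

import tensorflow as tf

def d_loss(real_scores, fake_scores):
    # SoftPlus on the discriminator's raw scores, for real and fake boxes
    return tf.reduce_mean(tf.nn.softplus(fake_scores) + tf.nn.softplus(-real_scores))

def g_loss(fake_scores):
    # Non-saturating generator loss
    return tf.reduce_mean(tf.nn.softplus(-fake_scores))

def r1_penalty(discriminator, real_images):
    # Gradient penalty on real images only
    with tf.GradientTape() as tape:
        tape.watch(real_images)
        scores = discriminator(real_images)
    grads = tape.gradient(tf.reduce_sum(scores), real_images)
    return tf.reduce_mean(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))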

Datasets

The datasets used are composed of:

  • Training:

    • 76,653 Latin text boxes from ICDAR MLT 17 and MLT 19 and their corresponding labels. Bad images (e.g. text boxes written vertically) are filtered out by computing the OCR loss of each text box and keeping only those below a threshold.

    • 121,175 words from a Wikipedia corpus and an English dictionary.

  • Validation: 5,000 words from a Wikipedia corpus and an English dictionary.

  • Test: 5,000 words from a Wikipedia corpus and an English dictionary.

The distribution of the characters in the text boxes' labels follows the distribution of characters in Latin languages. As a result, some characters (e.g. "e") appear significantly more often than others (e.g. "q"). To allow the network to perform well when generating every character, another dataset is used, with a more homogeneous character distribution, as detailed in Figure 4 below. This dataset is composed of words selected from a Wikipedia corpus and an English dictionary: when building it, each word added must contain the character that currently appears the least in the dataset. However, this dataset is not compatible with the OCR mean squared error loss, since that loss requires a real text box corresponding to each word.
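
The greedy selection described above can be sketched as follows (a self-contained illustration, not the repository's exact code):

from collections import Counter
import string

def build_balanced_corpus(candidate_words, target_size, alphabet=string.ascii_lowercase):
    # Each word added must contain the character currently rarest in the dataset
    counts = Counter({c: 0 for c in alphabet})
    selected = []
    pool = [w.lower() for w in candidate_words]
    while len(selected) < target_size and pool and counts:
        rarest = min(counts, key=counts.get)
        match = next((w for w in pool if rarest in w), None)
        if match is None:
            counts.pop(rarest)  # no remaining word contains this character
            continue
        pool.remove(match)
        selected.append(match)
        counts.update(c for c in match if c in counts)
    return selected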


Figure 4: Character Distribution

Inference

Unlike the GAN loss, the OCR loss converges and is hence the only metric measured during inference on the Validation and Test datasets. An average over many runs can be computed for a more accurate result: since random vectors are used to generate the style, the output score varies from one run to another.

The weights used for inference are the moving average of the trained weights (denoted g_clone in the code). Re-using the same style vector for generating two different words endows them with the same font and background. The generated text box is cropped depending on the word's length (cf. handling words of different sizes).
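
For instance, generating several words with a shared style could look like the following sketch; encode_word, latent_encoder, g_clone.synthesize and crop_to_word_length are illustrative names, not the repository's exact API:

import tensorflow as tf

z = tf.random.normal([1, z_dim])   # one noise vector ...
w = latent_encoder(z)              # ... hence one style, shared below
for word in ["Words", "with", "the", "same", "style"]:
    char_ids = encode_word(word)                 # hypothetical tokenizer
    box = g_clone.synthesize(char_ids, w)        # same font/background each time
    box = crop_to_word_length(box, len(word))    # cf. the masking section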

Comparing different training strategies


Figure 5.a: Losses tracked during 100K steps, with a batch size of 4, for different training strategies. The OCR loss shown is always the Softmax Crossentropy loss, regardless of the loss used for training.


Figure 5.b: Examples of words generated with models trained with different OCR losses. The input words are shown on the left; the first and second columns of images correspond respectively to the results obtained with the model "SCE: Not using the corpus dataset" (cf. Figure 5.a) and with the model trained with the MSE.

Even though different OCR losses are used in the experiments, the validation metric is always the SCE loss, since it is a good indicator of whether the text box is readable and contains the right text. As observed in Figure 5.a, when training with the MSE, the validation loss is approximately 5 times larger than when training with the SCE. The bad performance of the MSE training is emphasized in Figure 5.b, where the text boxes generated with the corresponding model are blurry and do not always match the input word.

Furthermore, randomly selecting input words from the corpus dataset did not noticeably decrease the OCR loss. However, as shown in Figure 6, it allowed the network to generate the characters appearing the least in the text box dataset more neatly.

In our model, when training for a large number of steps, the discriminator ultimately ends up beating the generator. Hence, it is better to have a low generator loss at the beginning of training. Following this observation, and from Figure 5, it can be deduced that selecting the input word from the corpus dataset with a 0.25 probability is preferable to using a 0.5 probability.


Figure 6: Emphasizing the importance of using the corpus dataset

Training the final model

Following the above observations, the network was trained for 225K steps, with a batch size of 16, using the Softmax Crossentropy loss as the OCR loss and selecting input words from the corpus dataset with a 0.25 probability.

The Softmax Crossentropy loss obtained on the Test set for this model is 6.38 (averaged over 100 runs). In comparison, the mean loss obtained when feeding the dataset's real images to the OCR is 1.27. Some style vectors lead to hardly readable text boxes, which explains the difference between the two losses. Words generated with this model are shown in Figure 1.

Limitations of the model

We observed three main limitations of our model. Ideas for addressing them are given in the Areas for improvement section below:

  • Changing the style has a weak effect on the shape of the characters (cf. Figure 1). For instance, there are different ways to form an "a", but the network always uses the same shape. From the video of the network generating the word "generate" at different steps, it can be observed that the model first learns to generate the character "a" with a shape that resembles an "o" (around 0:11). Later, it learns to generate it differently (around 0:15), with a shape that cannot be mistaken for any other character. This behaviour is the result of using the Softmax Crossentropy loss as the OCR loss.
  • The method used to handle words of different lengths constrains some ratios, such as the size of the characters or the spacing between them.
  • The words generated only use printed characters, since this is also the case in the text boxes of the datasets used (MLT 17 and 19).

Projector

The projector extracts the style from a text box in order to generate new words with the same style (code). To do so, the latent vector responsible for the style of the image is optimized using a perceptual loss between the target image and a text box generated with the same word. An OCR loss is also utilised, preventing the projector from finding a style that leads to unreadable text.
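
A simplified sketch of the projection loop; initial_w, num_steps, ocr_loss_weight, target_image, g_clone.synthesize, perceptual_loss, ocr and encode_word are assumptions, and ocr_sce_loss is the sketch from the OCR Loss section (see the projector code for the actual implementation):

import tensorflow as tf

w = tf.Variable(initial_w)                 # latent style vector to optimize
opt = tf.keras.optimizers.Adam(learning_rate=0.01)
char_ids = encode_word(text_on_the_image)  # word written on the target box

for step in range(num_steps):
    with tf.GradientTape() as tape:
        generated = g_clone.synthesize(char_ids, w)
        loss = perceptual_loss(generated, target_image) \
             + ocr_loss_weight * ocr_sce_loss(ocr(generated), char_ids)
    opt.apply_gradients([(tape.gradient(loss, w), w)])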

The style vector is saved in a .txt file and can be re-used to generate more words with the same style. Below is an example of how the style obtained by projecting a text box can be used to generate new words.


Figure 7: Projector examples. The top line corresponds to the original images, and the three others to words generated using the style found by the projector.

Conclusion

TextBoxGAN can generate readable text boxes corresponding to an input word, with various styles. However, from the limitations stated above, it can be deduced that an OCR trained only on data generated with our model would not generalise well. Hence, our model may be more appropriate for data augmentation, i.e. training with a mix of generated and real text boxes, at the risk of creating a bias towards certain character shapes. Nevertheless, considering that this is, to our knowledge, the first attempt to generate text boxes with a GAN, the results obtained are very satisfying.

Areas for improvement

To attempt to overcome some of the limitations identified for our model, the following ideas could be implemented:

  • Training the model while switching between the MSE and SCE losses. Indeed, when using the MSE, the objective is not to write the perfect word, but rather to write the word as it appears on the real text box. Doing so could allow the model to generate characters of different shapes when changing the style.
  • Using a recurrent layer in the word encoder may enable our model to generate words with linked characters (if a dataset with linked characters is used).

Acknowledgements

These are the different repositories used in our implementation:

Contact: noea@sicara.com