## Creates TensorFlow Graphs for Spark NLP NerDLApproach
TensorFlow: `1.13.1`


In [None]:
import numpy as np
import os
import tensorflow as tf
import string
import random
import math
import sys
import shutil

from ner_model import NerModel
from dataset_encoder import DatasetEncoder
from ner_model_saver import NerModelSaver
from pathlib import Path

In [None]:
# Show the installed TensorFlow version and the version of its bundled Keras,
# one per line, so they can be checked against the version stated above.
print(tf.__version__, tf.keras.__version__, sep='\n')

## SETTINGS

In [None]:
# The contrib variant is used everywhere except on Windows (os.name == 'nt').
use_contrib = os.name != 'nt'

# Prefix of the generated graph file name; it records which variant was built.
name_prefix = 'blstm' if use_contrib else 'blstm-noncontrib'

In [None]:
# Index of the GPU to use. 0 selects the first GPU; change it here to pick
# another device.
gpu_device = 0

In [None]:
def create_graph(ntags, embeddings_dim, nchars, lstm_size = 128):
    """Build the NerDL TensorFlow graph and write it to a binary ``.pb`` file.

    The graph is written to the current working directory as
    ``<name_prefix>_<ntags>_<embeddings_dim>_<lstm_size>_<nchars>.pb``.

    Args:
        ntags: Maximum number of tags. Must be at least the number of unique
            labels (including ``O`` when using IOB).
        embeddings_dim: Dimension of the pretrained word embeddings.
        nchars: Maximum number of characters processed.
        lstm_size: LSTM size passed to the context representation layer.

    Returns:
        None. When the Python or TensorFlow version is not the expected one,
        a message is printed and nothing is written.
    """
    # Only Python 3.0-3.6 passes this guard, so 3.7 is the first rejected
    # release and is the version the message has to name.
    if sys.version_info[0] != 3 or sys.version_info[1] >= 7:
        print('Python 3.7 or above not supported by tensorflow')
        return
    if tf.__version__ != '1.13.1':
        print('Spark NLP is compiled with TensorFlow 1.13.1, Please use such version.')
        print('Current TensorFlow version: ', tf.__version__)
        return
    tf.compat.v1.disable_v2_behavior()
    # Start from an empty default graph so repeated calls do not accumulate ops.
    tf.compat.v1.reset_default_graph()
    model_name = name_prefix+'_{}_{}_{}_{}'.format(ntags, embeddings_dim, lstm_size, nchars)
    # The context manager closes this session on exit, so no explicit close()
    # is needed.
    with tf.compat.v1.Session():
        # NOTE(review): the model is given session=None and the graph is read
        # from ner.session below, so NerModel presumably creates its own
        # session - confirm against ner_model.py.
        ner = NerModel(session=None, use_contrib=use_contrib, use_gpu_device=gpu_device)
        try:
            ner.add_cnn_char_repr(nchars, 25, 30)
            ner.add_bilstm_char_repr(nchars, 25, 30)
            ner.add_pretrained_word_embeddings(embeddings_dim)
            ner.add_context_repr(ntags, lstm_size, 3)
            ner.add_inference_layer(True)
            ner.add_training_op(5)
            ner.init_variables()
            # Kept for its side effect only: the Saver constructor adds
            # save/restore ops to the graph. Presumably the consumer of the
            # .pb file relies on them - do not remove.
            tf.compat.v1.train.Saver()
            file_name = model_name + '.pb'
            # as_text=False writes a binary protobuf instead of text.
            tf.io.write_graph(ner.session.graph, './', file_name, as_text=False)
        finally:
            # Release the model even if building or writing the graph fails.
            ner.close()

### Arguments of `create_graph`
- 1st argument: max number of tags (must be at least equal to the number of unique labels, including O if IOB)
- 2nd argument: embeddings dimension
- 3rd argument: max number of characters processed (must be at least the largest possible number of characters)
- 4th argument: LSTM size (optional, defaults to 128)

In [None]:
# # CoNLL 2003 - English - GloVe 100d
# create_graph(9, 100, 90)

# # CoNLL 2003 - English - GloVe 200d
# create_graph(9, 200, 90)

# # CoNLL 2003 - English - GloVe 300d
# create_graph(9, 300, 90)

# # CoNLL 2003 - English - BERT Base
# create_graph(9, 768, 90)

# # CoNLL 2003 - English - BERT Large
# create_graph(9, 1024, 90)

# # You get the idea :)

In [None]:
# Terminate the process immediately with exit status 0. os._exit() skips
# cleanup handlers and does not flush stdio buffers.
os._exit(0)