In [None]:
import json
from pprint import pprint
import zipfile
import time

import numpy as np
import random as rnd
import pandas as pd
import regex as re

from progressbar import ProgressBar

### Initialize required packages and load the relevant models

In [None]:
import spacy
print("Spacy version", spacy.__version__)

# To run spaCy on the GPU, uncomment the line below. It is placed before
# spacy.load() because the GPU has to be selected before a pipeline is loaded.
#spacy.require_gpu()

# Name of the spaCy pipeline. It is also inspected by the last cell of the
# notebook when choosing how to save the trained model.
lang_model = 'en_core_web_sm'
nlp = spacy.load(lang_model)

In [None]:
import gensim
from gensim.models import Word2Vec
#from gensim.models.fasttext import FastText
print("Gensim version:", gensim.__version__)

# Please provide your own word embedding: replace the placeholder path below
# with the location of a saved gensim Word2Vec model.
# get_embeddings() (defined further down) looks up the "<unk>" key for every
# out-of-vocabulary word, so the model must contain that key.
embeddings = Word2Vec.load("path_to_embedding_model")

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Activation, SimpleRNN, Bidirectional, Dropout

# Report the versions of the deep-learning stack used for this run.
for label, version in (("TF version:", tf.__version__),
                       ("Keras version", keras.__version__)):
    print(label, version)

In [None]:
from synthesis_action_retriever.utils import make_spacy_tokens
from text_cleanup import TextCleanUp
# Shared TextCleanUp instance; get_embeddings() below calls tc.cleanup_text()
# on every word before looking it up in the embedding vocabulary.
tc = TextCleanUp()

### Load annotated data

In [None]:
# Location of the annotated-sentence JSON file, relative to the notebook.
path_to_dataset = './data/synthesis_action_annotated_dataset_2021-10-17.json'

# Parse the whole file into memory in one go.
with open(path_to_dataset) as dataset_file:
    annotated_data = json.load(dataset_file)

In [None]:
# Size of the raw annotated dataset, before any length filtering.
n_annotated = len(annotated_data)
print("Number of annotated sentences: ", n_annotated)

In [None]:
# Keep only sentences whose number of annotated tokens lies strictly between
# the two thresholds (both bounds are exclusive).
min_tok_thresh = 5
max_tok_thresh = 50

all_sentences = []
for entry in annotated_data:
    n_tokens = len(entry["annotations"])
    if min_tok_thresh < n_tokens < max_tok_thresh:
        all_sentences.append(entry)

print("Number of sentences for training after thresholding tokens: ", len(all_sentences))

### Utils

In [None]:
# Chemical element symbols, grouped by symbol length (one and two letters).
elements_1 = ['H', 'B', 'C', 'N', 'O', 'F', 'P', 'S', 'K', 'V', 'Y', 'I', 'W', 'U']
elements_2 = ['He', 'Li', 'Be', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'Cl', 'Ar', 'Ca', 'Sc', 'Ti', 'Cr',
              'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr',
              'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'Xe',
              'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er',
              'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi',
              'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf',
              'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Fl', 'Lv']
# Characters counted as "numeric" by is_num_like: digits, signs and brackets.
num_set = set("0987654321+-()[]")

def is_formula_like(tok):
    """Heuristically decide whether the string `tok` looks like a chemical formula.

    Returns False when every character after the first one is a lowercase
    letter (this also covers empty and single-character strings, and ordinary
    words such as "The"). Otherwise returns True if at least one element
    symbol from `elements_2` or `elements_1` occurs as a case-sensitive
    substring of `tok`, and False if none does.
    """
    if all(c.islower() for c in tok[1:]):
        return False

    # Containing any element symbol is exactly the condition under which
    # stripping the symbols out of the token would make it shorter.
    return any(el in tok for el in elements_2) or any(el in tok for el in elements_1)

def is_num_like(tok):
    """Heuristically decide whether the string `tok` looks like a number or quantity.

    Returns True when any of the following holds:
      * more than half of the characters belong to `num_set`;
      * the token contains no alphabetic character at all;
      * the token starts with a digit and has no uppercase letters (e.g. "5h").

    An empty string is reported as not number-like instead of raising
    ZeroDivisionError in the ratio computation.
    """
    if not tok:
        return False

    numeric_chars = sum(1 for c in tok if c in num_set)
    if numeric_chars / len(tok) > 0.5:
        return True

    if not any(c.isalpha() for c in tok):
        return True

    return tok[0].isdigit() and tok.islower()

def replace_token_upd(tok, mode):
    """Map a token to the string that is looked up in the embedding model.

    Parameters:
        tok: token object exposing `.text` and `.lemma_` (the notebook passes
            spaCy tokens).
        mode: 'lemma' to return the token's lemma; any other value returns the
            lower-cased token text.

    Returns:
        '<num>' for number-like tokens and '<chem>' for formula-like tokens
        (both only when the text is longer than one character); otherwise the
        lemma or the lower-cased text, depending on `mode`.
    """
    text = tok.text

    # Empty and single-character tokens are never replaced by a placeholder.
    # Checking the length first also means is_num_like() is never called on an
    # empty string.
    if len(text) > 1:
        if is_num_like(text):
            return '<num>'
        if is_formula_like(text):
            return '<chem>'

    if mode == 'lemma':
        return tok.lemma_
    return text.lower()

### Setting labels

In [None]:
# Action classes with a dedicated label, in class-index order starting at 1.
_action_names = ['Starting', 'Mixing', 'Purification', 'Heating', 'Shaping', 'Cooling', 'Reaction']

# Tag name -> class index. The empty tag and 'Non-altering' both share class 0.
action2num = {"": 0, 'Non-altering': 0}
action2num.update((name, idx) for idx, name in enumerate(_action_names, start=1))

# Class index -> tag name. Class 0 is rendered as the empty string.
num2action = {0: ""}
num2action.update(enumerate(_action_names, start=1))

### Featurize tokens

In [None]:
stop_list = ['the', 'a', 'an', 'oftentimes', 'however', 'moreover', 'therefore', 'whereas', 'whereby', 'hence', 
             'thus', 'where']

def get_embeddings(word_, embed_model):
    """Return the embedding vector used as the model input for `word_`.

    Parameters:
        word_: token string, or one of the sentence markers "<start>" / "<end>".
        embed_model: trained gensim model exposing its vectors as `.wv`.

    Returns:
        A zero vector of the embedding size for "<start>" and "<end>".
        Otherwise `word_` is cleaned with `tc.cleanup_text` and lower-cased;
        the result's vector is returned if it is in the vocabulary, else the
        vector stored under "<unk>".

    Raises:
        KeyError: if the word is out of vocabulary and the embedding model has
            no "<unk>" entry.
    """
    keyed_vectors = embed_model.wv

    if word_ in ["<start>", "<end>"]:
        # Prefer KeyedVectors.vector_size; fall back to the attribute used by
        # gensim releases before 4.0.
        dim = getattr(keyed_vectors, "vector_size", None)
        if dim is None:
            dim = embed_model.trainables.layer1_size
        return np.zeros(dim, dtype=float)

    # gensim >= 4.0 exposes the vocabulary as `key_to_index` (accessing
    # `vocab` raises there); older releases only provide `vocab`.
    vocabulary = getattr(keyed_vectors, "key_to_index", None)
    if vocabulary is None:
        vocabulary = keyed_vectors.vocab

    word = tc.cleanup_text(word_).lower()
    if word in vocabulary:
        return keyed_vectors[word]
    return keyed_vectors["<unk>"]

In [None]:
train_frac = 0.8
# Fixed seed so that the train/test split is identical on every run of this cell.
split_seed = 42

num_classes = len(num2action)
featurized_sentences = []

training_size = int(train_frac*len(all_sentences))
print("Training size:", training_size)
print("Test size:", len(all_sentences) - training_size)

# Shuffle a copy with a private, seeded generator. Shuffling `all_sentences`
# in place with the global RNG would give a different split on each execution,
# so re-running this cell after training would move training sentences into
# the test set.
shuffled_sentences = list(all_sentences)
rnd.Random(split_seed).shuffle(shuffled_sentences)

test_sents = shuffled_sentences[training_size:]
training_sents = shuffled_sentences[:training_size]

bar = ProgressBar(max_value=len(training_sents))

for num, sentence in enumerate(training_sents):
    sentence_features = []
    sentence_labels = []

    # Build a spaCy Doc from the pre-tokenized words so the tokens line up
    # one-to-one with the annotations.
    spacy_tokens = spacy.tokens.Doc(nlp.vocab, words = [a["token"] for a in sentence["annotations"]])

    # Every sequence is wrapped in <start>/<end> steps whose label vector is all zeros.
    sentence_features.append(get_embeddings("<start>", embeddings))
    sentence_labels.append(np.zeros(num_classes))

    for word, annot in zip(spacy_tokens, sentence["annotations"]):
        embed_vec = get_embeddings(replace_token_upd(word, mode=""), embeddings)

        # Normalize tags exactly as the evaluation cell does, so training and
        # evaluation agree and these tag variants do not raise KeyError.
        tag = annot["tag"]
        tag = "Mixing" if "Mixing" in tag else tag
        tag = "" if tag == "Miscellaneous" else tag
        action_vec = keras.utils.to_categorical(action2num[tag], num_classes)

        sentence_features.append(embed_vec)
        sentence_labels.append(action_vec)

    sentence_features.append(get_embeddings("<end>", embeddings))
    sentence_labels.append(np.zeros(num_classes))

    featurized_sentences.append(dict(
            data = sentence_features,
            labels = sentence_labels
        ))

    # Report the number of sentences completed so the bar reaches its maximum.
    bar.update(num + 1)

print(len(featurized_sentences))
print(len(featurized_sentences[0]['data'][0]))  

In [None]:
# Size of one word vector. Prefer KeyedVectors.vector_size and fall back to
# `trainables.layer1_size`, which only exists in gensim releases before 4.0.
input_word_dim = getattr(embeddings.wv, "vector_size", None)
if input_word_dim is None:
    input_word_dim = embeddings.trainables.layer1_size

# Longest featurized training sequence (its <start>/<end> steps included).
seq_len = max(len(d["data"]) for d in featurized_sentences)
output_dim = num_classes

print("Input word dimension:", input_word_dim)
print("Input sequence length:", seq_len)
print("Output dimension:", output_dim)
print("Output sequence length (same as input):", seq_len)

In [None]:
# Zero-padded training tensors: one row per featurized sentence, `seq_len`
# steps per row. Steps beyond a sentence's own length stay all-zero.
n_sentences = len(featurized_sentences)
input_sentences_data = np.zeros((n_sentences, seq_len, input_word_dim), dtype='float32')
output_tags_data = np.zeros((n_sentences, seq_len, output_dim), dtype='float32')

for i, data in enumerate(featurized_sentences[:training_size]):
    step_pairs = zip(data["data"], data["labels"])
    for t, (word, tag) in enumerate(step_pairs):
        input_sentences_data[i, t] = word
        output_tags_data[i, t] = tag

In [None]:
# Shape of one padded sentence: inputs first, then tag targets.
for padded in (input_sentences_data, output_tags_data):
    print(padded[0].shape)

In [None]:
# Number of units in each direction of the recurrent layer.
latent_dim = 32

model = None

# Variable-length sequences of word vectors go in ...
X = Input(shape=(None, input_word_dim))

# ... through a bidirectional SimpleRNN that emits one output per step.
# The result keeps the name `lstm` although the layer is a SimpleRNN.
# (uni-directional alternative: SimpleRNN(latent_dim, return_sequences=True)(X))
recurrent_layer = SimpleRNN(latent_dim, return_sequences=True, dropout=0.2, recurrent_dropout=0.2)
lstm = Bidirectional(recurrent_layer)(X)

# ... and each step is projected to `output_dim` scores followed by a softmax.
dense = Dense(output_dim)(lstm)
prediction = Activation("softmax")(dense)

model = Model(inputs=X, outputs=prediction)

model.summary()

In [None]:
batch_size = 128
epochs = 64

model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

# The last 20% of the training tensors are held out by Keras for validation.
fit_options = dict(batch_size=batch_size, epochs=epochs, validation_split=0.2)
model.fit([input_sentences_data], output_tags_data, **fit_options)

In [None]:
# Spot check: tag one randomly chosen sentence and print gold vs. predicted tags.
idx = rnd.randint(0, len(all_sentences)-1)
test_sent = all_sentences[idx]

test_words = [a["token"] for a in test_sent["annotations"]]
test_tags = [a["tag"] for a in test_sent["annotations"]]

print(test_words)
print([num2action[action2num[t]] for t in test_tags])

spacy_tokens = spacy.tokens.Doc(nlp.vocab, words = test_words)

# Lay the sentence out exactly like a training row: <start>, tokens, <end>.
# A separate buffer is used so the training tensor `input_sentences_data` is
# not overwritten, and it is sized from the sentence itself (the model accepts
# any sequence length), so a long sentence cannot overflow it.
sample_input = np.zeros((1, len(test_words) + 2, input_word_dim), dtype='float32')
sample_input[0, 0] = get_embeddings("<start>", embeddings)
for t, word in enumerate(spacy_tokens, start=1):
    sample_input[0, t] = get_embeddings(replace_token_upd(word, mode=""), embeddings)
sample_input[0, -1] = get_embeddings("<end>", embeddings)

# Drop the <start>/<end> steps so the predictions line up with `test_words`.
result = model.predict(sample_input)[0][1:-1]
tags_predicted = [num2action[np.argmax(pred_vec)] for pred_vec in result]

print(tags_predicted)

### Test

In [None]:
# Sentence-level outcome buckets.
correct = []      # predicted tag sequence equals the gold sequence
missing = []      # fewer non-empty tags predicted than annotated
extra = []        # more non-empty tags predicted than annotated
wrong_tag = []    # same number of non-empty tags, but the sequences differ

# Buckets used for the precision/recall computation.
tn = []
tp = []
fp = []
fn = []


bar = ProgressBar(max_value = len(test_sents))
for sent_num, sentence in enumerate(test_sents):

    words = [a["token"] for a in sentence["annotations"]]
    tags = [a["tag"] for a in sentence["annotations"]]

    correct_tags = []
    for t in tags:
        t = "Mixing" if "Mixing" in t else t
        t = "" if t == "Miscellaneous" else t
        # Map known tags through the label tables so that every alias of
        # class 0 (e.g. 'Non-altering') is written as "", the form the
        # predictions use. Unknown tags are kept and will count as a mismatch.
        if t in action2num:
            t = num2action[action2num[t]]
        correct_tags.append(t)

    spacy_tokens = spacy.tokens.Doc(nlp.vocab, words = words)

    # Same layout as the training rows: <start>, tokens, <end>. Two extra
    # steps are reserved so the markers do not overwrite the first and last
    # real token. Local names keep the notebook-level `seq_len` and
    # `input_sentences_data` intact.
    n_steps = len(words) + 2
    test_input = np.zeros((1, n_steps, input_word_dim), dtype='float32')
    test_input[0, 0] = get_embeddings("<start>", embeddings)
    for t, word in enumerate(spacy_tokens, start=1):
        embed_vec = get_embeddings(replace_token_upd(word, mode=""), embeddings)
        test_input[0, t] = embed_vec
    test_input[0, -1] = get_embeddings("<end>", embeddings)

    # verbose=0 keeps Keras from printing its own progress line per sentence.
    result = model.predict(test_input, verbose=0)[0]
    # Drop the <start>/<end> steps so predictions align with the tokens.
    tags_predicted = [num2action[np.argmax(v)] for v in result[1:-1]]

    sentence["prediction"] = tags_predicted
    sentence["correct"] = correct_tags

    n_predicted = len([t for t in tags_predicted if t != ""])
    n_expected = len([t for t in correct_tags if t != ""])

    if tags_predicted == correct_tags:
        correct.append(sentence)
        # An exact match with no action tags at all is a true negative.
        if any(correct_tags):
            tp.append(sentence)
        else:
            tn.append(sentence)
    elif n_predicted > n_expected:
        extra.append(sentence)
        fp.append(sentence)
    elif n_predicted < n_expected:
        missing.append(sentence)
        fn.append(sentence)
    else:
        wrong_tag.append(sentence)

    bar.update(sent_num + 1)

print("Correct:", len(correct))
print("Extra:", len(extra))
print("Missing:", len(missing))
print("Wrong:", len(wrong_tag))
print("Test set:", len(test_sents))

In [None]:
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0

# Sentence-level metrics from the tp/tn/fp/fn buckets filled by the evaluation
# loop. Each ratio is guarded so that an empty bucket combination (e.g. no
# positive predictions at all) yields 0.0 instead of a ZeroDivisionError.
prec = _safe_ratio(len(tp), len(tp)+len(fp))
recall = _safe_ratio(len(tp), len(tp)+len(fn))
accuracy = _safe_ratio(len(tp) + len(tn), len(tp)+len(tn)+len(fp)+len(fn))
f1 = _safe_ratio(2.0*prec*recall, prec + recall)

print("Precision:", round(prec, 2))
print("Recall:", round(recall, 2))
print("Accuracy:", round(accuracy, 2))
print("F1:", round(f1, 2))

In [None]:
# Timestamp suffix so that successive runs never overwrite each other.
timestr = time.strftime("%Y%m%d-%H%M%S")

# The save call depends on which spaCy pipeline name was configured:
# tf.saved_model.save for 'en_core_web_trf', Keras model.save otherwise.
if lang_model != 'en_core_web_trf':
    model.save("./output/Bi-RNN_cl7_ed100_{}".format(timestr))
else:
    tf.saved_model.save(model, './output/Bi-RNN_cl7_ed100_TF_{}'.format(timestr))