In [2]:
from tensorflow.keras.layers import Input, Lambda, Bidirectional, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras import optimizers
from keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,LSTM, Embedding, Dense, TimeDistributed, Dropout, Conv1D,GRU,BatchNormalization, Concatenate
from tensorflow.keras.layers import Bidirectional, concatenate, SpatialDropout1D, GlobalMaxPooling1D,dot,Activation
from tensorflow.keras import optimizers
from keras.utils import to_categorical
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ReduceLROnPlateau
from seqeval.metrics import classification_report, accuracy_score
import transformers
import torch
from keras import backend as K 
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint,ReduceLROnPlateau

In [3]:
from transformers import BertModel, BertConfig

# A single checkpoint name feeds the config, the tokenizer and the encoder.
PRETRAINED_CHECKPOINT = 'gsarti/biobert-nli'

tokenizer_class = transformers.BertTokenizer
model_class = transformers.BertModel

# output_hidden_states=True is requested because get_word_embeddings reads the
# per-layer hidden states from the model output.
config = BertConfig.from_pretrained(
    PRETRAINED_CHECKPOINT,
    output_hidden_states=True,
)
tokenizer = tokenizer_class.from_pretrained(PRETRAINED_CHECKPOINT)
bert_model = model_class.from_pretrained(PRETRAINED_CHECKPOINT, config=config)

In [19]:
#Import and process training and test data
import pickle

# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load this file if it comes from a trusted source.
with open('tagged_abstracts.pickle', 'rb') as handle:
    abstracts = pickle.load(handle)

# Every abstract is iterated as a sequence of items where item[0] is the word
# and item[1] is its tag.
# The unique values are sorted so that the word/tag -> index assignment is the
# same on every run. Iteration order of a set of strings is not stable across
# interpreter restarts, which would silently change the ids (and therefore the
# meaning of saved weights and embedding rows) between runs.
words = sorted({wrd[0] for sent in abstracts for wrd in sent})
words.append('ENDPAD')
tags = sorted({wrd[1] for sent in abstracts for wrd in sent})

max_len = 200
# Index 0 is left free for padding, so real words start at 1.
word2idx = {w: i + 1 for i, w in enumerate(words)}
tag2idx = {t: i for i, t in enumerate(tags)}
idx2words = {i: w for w, i in word2idx.items()}

# Plain-text version of each abstract (input to the BERT tokenizer) and the raw
# tag sequences.
docs = [" ".join([w[0] for w in s]) for s in abstracts]
yy = [[w[1] for w in s] for s in abstracts]

X = [[word2idx[w[0]] for w in s] for s in abstracts]
X = pad_sequences(maxlen=max_len, sequences=X, padding="post", value=0)

n_tags = 3
if len(tags) > n_tags:
    # Fail early with a readable message instead of an index error inside
    # to_categorical.
    raise ValueError(
        "Found %d distinct tags but n_tags is %d" % (len(tags), n_tags)
    )

y = [[tag2idx[w[1]] for w in s] for s in abstracts]
# Padded label positions are filled with the index of the "O" tag. The ids are
# plain integers, so an integer array is used rather than an object array.
y = pad_sequences(maxlen=max_len, sequences=y, padding="post", value=tag2idx["O"], dtype="int32")
y = [to_categorical(i, num_classes=n_tags) for i in y]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=2018)


In [20]:
# Same arguments and fixed random_state as the split at the end of the
# previous cell.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=2018,
)

In [12]:
#Get Bert Embeddings

def get_word_embeddings(text):
    """Compute one BERT vector per whole word of `text` and cache it.

    The text is wrapped in "[CLS] ... [SEP]", tokenized with the module-level
    `tokenizer` and run through the module-level `bert_model`. Each token is
    represented by the sum of its last four hidden-state layers; word pieces
    (tokens starting with "##") are merged back into the preceding word and
    their vectors averaged.

    Args:
        text (str): the document to embed.

    Returns:
        None. Results are written into the module-level dict
        `final_embed_vecs` (word string -> 1-D torch tensor). A word seen again
        in a later call overwrites the vector stored by an earlier call.
    """
    max_tokens = 500

    tokenized_text = tokenizer.tokenize("[CLS] " + text + " [SEP]")
    if len(tokenized_text) > max_tokens:
        # Keep the closing [SEP] marker instead of cutting it off with the tail.
        tokenized_text = tokenized_text[:max_tokens - 1] + ["[SEP]"]

    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    # A tensor of ones used to be passed as the second positional argument,
    # which BertModel.forward interprets as the attention mask. It is passed by
    # keyword here so the intent is explicit; the values are unchanged.
    attention_mask = torch.ones_like(tokens_tensor)

    with torch.no_grad():
        outputs = bert_model(tokens_tensor, attention_mask=attention_mask)

    # outputs[2] is the tuple of per-layer hidden states (available because the
    # config sets output_hidden_states=True).
    hidden_states = torch.stack(outputs[2], dim=0)        # (layers, 1, tokens, hidden)
    hidden_states = torch.squeeze(hidden_states, dim=1)   # (layers, tokens, hidden)

    # Sum the vectors from the last four layers for every token at once.
    token_vecs_sum = hidden_states[-4:].sum(dim=0)        # (tokens, hidden)

    # Merge word pieces back into whole words, keeping a running sum and a
    # piece count so the result is the true mean over all pieces of a word.
    merged_tokens = []
    piece_sums = []
    piece_counts = []
    for token_str, token_vec in zip(tokenized_text, token_vecs_sum):
        if token_str.startswith("##") and merged_tokens:
            merged_tokens[-1] = merged_tokens[-1] + token_str[2:]
            piece_sums[-1] = piece_sums[-1] + token_vec
            piece_counts[-1] += 1
        else:
            merged_tokens.append(token_str)
            piece_sums.append(token_vec)
            piece_counts.append(1)

    for tk, total, count in zip(merged_tokens, piece_sums, piece_counts):
        # The division also yields a fresh tensor, so the cache does not keep
        # the whole hidden-state tensor of this document alive.
        final_embed_vecs[tk] = total / count
        

# Word -> vector cache that get_word_embeddings fills in as a side effect.
final_embed_vecs = {}
for doc_text in docs:
    get_word_embeddings(doc_text)

In [16]:
#Word input embedding matrix
vocab_size = len(words) + 1
embedding_dim = 768

# A seeded generator makes the random rows identical on every run. Values are
# still drawn uniformly from [0, 1).
embedding_rng = np.random.default_rng(2018)
embedding_matrix = embedding_rng.random((vocab_size, embedding_dim))

# Rows of words that have a BERT vector are overwritten with that vector.
# Row 0 (the padding index) and words without a vector keep their random
# initial values, so no second random draw is needed for them.
for word, i in word2idx.items():
    vec = final_embed_vecs.get(word)
    if vec is not None:
        embedding_matrix[i] = np.asarray(vec)

In [21]:
#Document embedding for train and test sets
#We average the Bert word embeddings for document

import re

# Characters other than ASCII letters and '-' are stripped from a word before
# it is looked up in the embedding cache.
_NON_WORD_CHARS = re.compile('[^a-zA-Z-]+')


def build_document_matrix(index_rows):
    """Build the per-document embedding input for a set of padded id rows.

    For every row, the cached BERT vectors of its words are averaged into one
    document vector, and that vector is repeated once per position of the row.

    Args:
        index_rows: iterable of rows of word ids (e.g. X_train or X_test). Ids
            are resolved through the module-level `idx2words`; vectors come
            from the module-level `final_embed_vecs`.

    Returns:
        list: one entry per row; each entry is a list with `len(row)`
        references to the same 1-D numpy document vector.
    """
    matrix = []
    for row in index_rows:
        vectors = []
        for idx in row:
            word = idx2words.get(idx)
            if word is None:
                # e.g. the padding id 0, which has no vocabulary entry
                continue
            vec = final_embed_vecs.get(_NON_WORD_CHARS.sub('', word))
            if vec is not None:
                vectors.append(vec)

        if vectors:
            doc_vec = torch.mean(torch.stack(vectors), dim=0).detach().numpy()
        else:
            # No word of this document has a cached vector: use a zero vector
            # of the cache's dimensionality instead of failing on an empty stack.
            if not final_embed_vecs:
                raise ValueError("final_embed_vecs is empty; run the BERT embedding cell first")
            vec_dim = next(iter(final_embed_vecs.values())).shape[-1]
            doc_vec = np.zeros(vec_dim, dtype=np.float32)

        matrix.append([doc_vec] * len(row))
    return matrix


train_matrix = build_document_matrix(X_train)
test_matrix = build_document_matrix(X_test)

In [25]:
# CRF Class 

from tensorflow_addons.text import crf_log_likelihood, crf_decode
import tensorflow.keras.backend as K
import tensorflow.keras.layers as L
import tensorflow as tf

class CRF(L.Layer):
    """Linear-chain CRF output layer built on `crf_decode` / `crf_log_likelihood`.

    The layer owns a square `transitions` weight of size `output_dim`. In the
    training phase `call` returns its input unchanged (so `loss` receives the
    unary scores); otherwise it returns the one-hot encoded Viterbi decoding.
    Use `layer.loss` and `layer.accuracy` when compiling the model.
    """

    def __init__(self,
                 output_dim,
                 sparse_target=True,
                 **kwargs):
        """
        Args:
            output_dim (int): the number of labels to tag each temporal input.
            sparse_target (bool): when True, `y_true` is reduced with an argmax
                over its last axis before it is used by the loss and the
                accuracy (i.e. `y_true` is given one-hot encoded); when False,
                `y_true` is used as passed in.
        Input shape:
            (batch_size, sentence length, output_dim)
        Output shape:
            (batch_size, sentence length, output_dim)
        """
        super(CRF, self).__init__(**kwargs)
        self.output_dim = int(output_dim)
        self.sparse_target = sparse_target
        self.input_spec = L.InputSpec(min_ndim=3)
        self.supports_masking = False
        # Set by `call`; read by the loss closure.
        self.sequence_lengths = None
        # Created in `build`.
        self.transitions = None

    def build(self, input_shape):
        """Validate the input shape and create the transition weight.

        Raises:
            ValueError: if the last input dimension is undefined or differs
                from `output_dim`.
        """
        assert len(input_shape) == 3
        f_shape = tf.TensorShape(input_shape)
        input_spec = L.InputSpec(min_ndim=3, axes={-1: f_shape[-1]})

        if f_shape[-1] is None:
            raise ValueError('The last dimension of the inputs to `CRF` '
                             'should be defined. Found `None`.')
        if f_shape[-1] != self.output_dim:
            raise ValueError('The last dimension of the input shape must be equal to output'
                             ' shape. Use a linear layer if needed.')
        self.input_spec = input_spec
        self.transitions = self.add_weight(name='transitions',
                                           shape=[self.output_dim, self.output_dim],
                                           initializer='glorot_uniform',
                                           trainable=True)
        self.built = True

    def compute_mask(self, inputs, mask=None):
        # Just pass the received mask from previous layer, to the next layer or
        # manipulate it if this layer changes the shape of the input
        return mask

    def call(self, inputs, sequence_lengths=None, training=None, **kwargs):
        """Return the inputs in the training phase, the one-hot Viterbi path otherwise.

        Args:
            inputs: tensor of unary scores, rank 3.
            sequence_lengths: optional int32 tensor of shape (batch, 1). When
                omitted, every sequence is treated as spanning the full second
                dimension of `inputs`.
            training: forwarded to `K.in_train_phase`; None falls back to the
                backend learning phase.
        """
        sequences = tf.convert_to_tensor(inputs, dtype=self.dtype)
        if sequence_lengths is not None:
            assert len(sequence_lengths.shape) == 2
            assert tf.convert_to_tensor(sequence_lengths).dtype == 'int32'
            seq_len_shape = tf.convert_to_tensor(sequence_lengths).get_shape().as_list()
            assert seq_len_shape[1] == 1
            self.sequence_lengths = K.flatten(sequence_lengths)
        else:
            self.sequence_lengths = tf.ones(tf.shape(inputs)[0], dtype=tf.int32) * (
                tf.shape(inputs)[1]
            )

        viterbi_sequence, _ = crf_decode(sequences,
                                         self.transitions,
                                         self.sequence_lengths)
        output = K.one_hot(viterbi_sequence, self.output_dim)
        # Honour an explicitly passed `training` flag instead of ignoring it.
        return K.in_train_phase(sequences, output, training=training)

    @property
    def loss(self):
        """Loss function (mean negative CRF log-likelihood) bound to this layer."""
        def crf_loss(y_true, y_pred):
            y_pred = tf.convert_to_tensor(y_pred, dtype=self.dtype)
            if self.sparse_target:
                targets = tf.cast(K.argmax(y_true), dtype=tf.int32)
            else:
                targets = y_true

            sequence_lengths = self.sequence_lengths
            if sequence_lengths is None:
                # `call` has not run yet: treat every sequence as full length,
                # the same default that `call` and the accuracy metric use.
                shape = tf.shape(y_pred)
                sequence_lengths = tf.ones(shape[0], dtype=tf.int32) * shape[1]

            # crf_log_likelihood also returns transition parameters. They are
            # discarded on purpose so that `self.transitions` keeps pointing at
            # the weight created in `build`.
            log_likelihood, _ = crf_log_likelihood(
                y_pred,
                targets,
                sequence_lengths,
                transition_params=self.transitions,
            )
            return tf.reduce_mean(-log_likelihood)
        return crf_loss

    @property
    def accuracy(self):
        """Metric function: token accuracy of the Viterbi decoding of `y_pred`."""
        def viterbi_accuracy(y_true, y_pred):
            # -1e10 to avoid zero at sum(mask)
            mask = K.cast(
                K.all(K.greater(y_pred, -1e10), axis=2), K.floatx())
            shape = tf.shape(y_pred)
            sequence_lengths = tf.ones(shape[0], dtype=tf.int32) * (shape[1])
            y_pred, _ = crf_decode(y_pred, self.transitions, sequence_lengths)
            if self.sparse_target:
                y_true = K.argmax(y_true, 2)
            y_pred = K.cast(y_pred, 'int32')
            y_true = K.cast(y_true, 'int32')
            corrects = K.cast(K.equal(y_true, y_pred), K.floatx())
            return K.sum(corrects * mask) / K.sum(mask)
        return viterbi_accuracy

    def compute_output_shape(self, input_shape):
        """Output shape: the first two input dimensions followed by `output_dim`."""
        tf.TensorShape(input_shape).assert_has_rank(3)
        return input_shape[:2] + (self.output_dim,)

    def get_config(self):
        """Return the constructor arguments merged into the base layer config.

        Only values that `__init__` accepts are included, so the result can be
        passed back to the constructor. The transition matrix is a layer weight
        and is therefore not part of the config.
        """
        config = {
            'output_dim': self.output_dim,
            'sparse_target': self.sparse_target,
        }
        base_config = super(CRF, self).get_config()
        return dict(base_config, **config)

In [26]:
#Model architecture 

# Dense and CRF must agree on the number of labels, so both use n_tags.
num_labels = n_tags
embedding_size = 768
hidden_size = 4

#Input the word embeddings  
# Named word_input so the builtin input() is not shadowed in the kernel.
word_input = Input(shape=(max_len,))
word_features = Embedding(vocab_size, embedding_size, trainable=False, weights=[embedding_matrix], input_length=max_len)(word_input)
word_features = Dropout(0.55)(word_features)
word_features = Bidirectional(LSTM(64, return_sequences=True))(word_features)

#Input the document embedding matrix
auxiliary_input = Input(shape=(max_len, embedding_size))
x = Bidirectional(LSTM(hidden_size, return_sequences=True))(auxiliary_input)
x = Dropout(0.65)(x)

#Calculate the alignment score and apply softmax
# Both dot products contract the named axes of their two inputs.
attention = dot([word_features, x], axes=[1, 1])
attention = Activation('softmax')(attention)
context = dot([attention, word_features], axes=[1, 2])
context = BatchNormalization(momentum=0.6)(context)
# Swap the last two axes so the context lines up with word_features for the
# concatenation below.
context = K.permute_dimensions(context, (0, 2, 1))

#Concatenate the context vector and the word inputs 
concated = concatenate([word_features, context])
main_output = Dense(num_labels)(concated)

#Pass output through CRF layer 
crf = CRF(num_labels)
out = crf(main_output)

model = Model(inputs=[word_input, auxiliary_input], outputs=out)
# `learning_rate` is the current name of the argument; `lr` is deprecated.
adam = optimizers.Adam(learning_rate=0.001)
model.compile(adam, loss=crf.loss, metrics=[crf.accuracy])

In [27]:
# All three callbacks monitor the validation loss in 'min' mode.
earlyStopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=50,
    verbose=0,
)
mcp_save = ModelCheckpoint(
    'keywords_SL.hdf5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
)
reduce_lr_loss = ReduceLROnPlateau(
    monitor='val_loss',
    mode='min',
    factor=0.1,
    patience=20,
    min_lr=0.00001,
    verbose=1,
)

In [34]:
num_epochs = 300
# Two inputs: padded word ids and the repeated document vectors.
history = model.fit(
    [np.array(X_train), np.array(train_matrix)],
    np.asarray(y_train),
    batch_size=128,
    epochs=num_epochs,
    validation_split=0.15,
    verbose=1,
    callbacks=[earlyStopping, mcp_save, reduce_lr_loss],
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


In [35]:
test_pred = model.predict(
    [np.array(X_test), np.array(test_matrix)],
    verbose=1,
)
# Inverse of tag2idx: class index -> tag string.
idx2tag = dict(zip(tag2idx.values(), tag2idx.keys()))

def pred2label(pred, tag_lookup=None):
    """Convert per-token class scores into tag strings.

    Args:
        pred: iterable of sequences; each sequence is an iterable of per-token
            score vectors (e.g. model output or one-hot labels).
        tag_lookup (dict, optional): class index -> tag string. Defaults to the
            module-level `idx2tag`.

    Returns:
        list[list[str]]: for every sequence, the tag of the arg-max class of
        each token, with "PAD" replaced by "O" in the tag string.
    """
    if tag_lookup is None:
        tag_lookup = idx2tag

    return [
        [tag_lookup[int(np.argmax(scores))].replace("PAD", "O") for scores in sequence]
        for sequence in pred
    ]
    
# Turn both the model output and the one-hot ground truth into tag strings.
pred_labels = pred2label(test_pred)
test_labels = pred2label(y_test)

report = classification_report(test_labels, pred_labels)
print(report)

           precision    recall  f1-score   support

      KEY       0.74      0.68      0.71     12568

micro avg       0.74      0.68      0.71     12568
macro avg       0.74      0.68      0.71     12568

