# NERNN: Named Entity Recognition with Word Embeddings and Char RNNs
- Use a combination of word embeddings and character embeddings + RNN to predict whether a word is a named entity

In [1]:
# Imports
import os
import gzip
import glob

# Computational imports
import numpy as np
import pandas as pd
import tensorflow as tf
tf.reset_default_graph()

# Keras imports
import keras
import keras.backend as K
from keras.preprocessing import sequence
from keras.models import Sequential, Model
from keras.layers import Layer, Input, Embedding, merge, Dense, GRU, TimeDistributed, Layer, Bidirectional
from model_saver import ModelCheckpointBatch

# gensim imports
import gensim
from gensim.models import Word2Vec

%load_ext autoreload
%autoreload

Using TensorFlow backend.


# I. Define Model Architecture

### Model Vars

In [15]:
MODEL_SAVE_PATH = 'models/weights.{epoch:02d}-{batch:02d}.hdf5'

## a. Defining the Keras model architecture

In [16]:
# Going to need a custom layer for selecting the end of words in the character RNN
class GatherLayer(Layer):
    '''
    Scans over the batch to gather specific indices along the time axis
    '''
    def __init__(self, **kwargs):
        super(GatherLayer, self).__init__(**kwargs)
    
    def build(self, input_shape):
        super(GatherLayer, self).build(input_shape)
        
    def compute_mask(self, x, mask=None):
        '''
        Compute the mask
        '''
        return K.cast(K.not_equal(x[1], -1), 'bool')
    
    def call(self, inputs, mask=None):
        '''
        First input is the rnn out (batch_size, max_word_steps, char_lstm_dim)
        Second input is indicies to gather (batch_size, max_word_steps)
        '''
        rnn_inp = inputs[0]
        ind_inp = inputs[1]
        
        ind_inp_zeroed = tf.select(tf.not_equal(ind_inp, -1), ind_inp, tf.zeros_like(ind_inp, dtype='int64'))
        
        def f(inp):
            '''
            Gathers the inds for the input mat of (max_char_len, char_lstm_dim)
            '''
            mat = inp[0]
            inds = inp[1]
            return tf.gather(mat, inds)
        
        map_fn_out = tf.map_fn(f, elems=(rnn_inp, ind_inp_zeroed), dtype='float32')
        
        return map_fn_out
    
    def get_output_shape_for(self, input_shape):
        rnn_shape = input_shape[0]
        ind_shape = input_shape[1]
        return (rnn_shape[0], ind_shape[1], rnn_shape[2])

In [17]:
class SegmentLayer(Layer):
    '''
    Takes a segmented sum
    '''
    def __init__(self, seg_func_name='sum', **kwargs):
        super(SegmentLayer, self).__init__(**kwargs)
        if seg_func_name == 'sum':
            self.seg_func = tf.segment_sum
        elif seg_func_name == 'mean':
            self.seg_func = tf.segment_mean
        elif seg_func_name == 'max':
            self.seg_func = tf.segment_max
        else:
            self.seg_func = tf.segment_sum
        
    def build(self, input_shape):
        super(SegmentLayer, self).build(input_shape)
    
    def compute_mask(self, x, mask=None):
        return K.cast(K.not_equal(x[1], -1), 'bool')
    
    def call(self, x, mask=None):
        rnn_inp = x[0]
        word_end_idx = x[1]
        segment_mask = x[2]

        # Need the max doc len
        max_word_len = tf.shape(word_end_idx)[1]
        
        def f(inp):
            '''
            Performs a segmented sum on each input of (max_char_len, char_lstm_dim)
            '''
            mat = inp[0]
            seg_mask = inp[1]
            
            # perform the segmented sum along the first axis
            seg_sum = self.seg_func(mat, seg_mask)[:-1]  # don't want the last segment
            
            # need to pad the result such that we always have vectors of max_token_len
            seg_sum_shape = tf.shape(seg_sum)
            zero_pad = tf.zeros((max_word_len - seg_sum_shape[0], seg_sum_shape[1]), dtype='float32')
            seg_sum_padded = tf.concat(0, [seg_sum, zero_pad])
            
            return seg_sum_padded
        
        map_fn_out = tf.map_fn(f, elems=(rnn_inp, segment_mask), dtype='float32')
        
        return map_fn_out
    
    def get_output_shape_for(self, input_shape):
        rnn_shape = input_shape[0]
        ind_shape = input_shape[1]
        
        return (rnn_shape[0], ind_shape[1], rnn_shape[2])

In [18]:
# Define a method for setting the models word_embedding layer to have pretrained embeddings
def set_embeddings(embedding, weights, token_map):
    '''
    Takes in a gensim Word2Vec model and our keras model and then adapts the word_embeddings
    in our model to the pre-trained vectors from the w2v model if they exist. Otherwise, they are
    left to the original initalization
    
    Paramters
    =========
    w2v_model : gensim.models.word2vec.Word2Vec
        Word2Vec model that is already loaded with pre-trained embedings
    keras_model : keras.engine.training.Model
        Keras model that has been constructed such that the word embedding layer has the same
        number of embedding dimensions as the pre-trained embeddings
    token_map : dict
        map from token to index in the vocab
    '''    
    for token, ind in token_map.iteritems():
        try:
            pre_trained_emb = w2v_model[token]     
        except:
            pre_trained_emb = weights[ind]
        weights[ind] = pre_trained_emb
    return weights

In [19]:
def construct_model(word_vocab_size, char_vocab_size,
                    w_emb_dim=100, w_lstm_dim=128, 
                    c_emb_dim=100, c_lstm_dim=128, 
                    embedding=None, token2idx=None, trainable_embedding=True, dropout=0.5, 
                    char_emb=True):
    '''
    Constructs the NERNN in keras / tensorflow
    
    Parameters
    ==========
    word_vocab_size : int
        size of the token vocabulary
    char_vocab_size : int
        size of the character vocabulary
    w_emb_dim : int
        size of token lstm layer
    c_emb_dim : int
        size of the char lstm layer
    embedding : optional, gensim.model.word2vec.Word2Vec
        pre-trained embeddings passed in the form of a gensim model
    token2idx : optional, dict
        map from tokens to indices
    trainable_embedding : bool
        indicates whether word embeddings are trainable
    chat_emb : bool
        indicates whether to use character embeddings
    '''
    
    # =========================== #
    # 1. Construct word Embedding #
    # =========================== #

    word_model_in = Input(shape=(None,), name='word_enc_in')
    
    if embedding:
        
        # Make some assertions
        assert token2idx is not None, "Must pass a map from token to idx for current training set with an embedding"
        assert w_emb_dim == embedding.vector_size, "Word embedding dimension must be same as passed embedding"
        
        # Set the weights to the pre-trained weights from the passed embedding
        weights = np.random.uniform(-0.05, 0.05, size=(word_vocab_size+2, w_emb_dim))
        weights = set_embeddings(embedding, weights, token2idx)
        
        # Construct the embedding layer
        word_emb = Embedding(input_dim=word_vocab_size+2, output_dim=w_emb_dim, 
                             input_length=None, mask_zero=True, 
                             weights=[weights], name='word_embedding')
        
    else:
        # Otherwise randomly initialize weights
        word_emb = Embedding(input_dim=word_vocab_size+2, output_dim=w_emb_dim,
                             input_length=None, mask_zero=True, name='word_embedding')

    word_out = word_emb(word_model_in)
    word_emb.trainable = trainable_embedding

    # ============================== #
    # 2. Construct the character RNN #
    # ============================== #

    if char_emb:
        char_model_in = Input(shape=(None,), name='char_enc_in')
        char_emb_out = Embedding(input_dim=char_vocab_size+2, output_dim=c_emb_dim, 
                                 input_length=None, mask_zero=True, name='char_embedding')(char_model_in)
        char_out = Bidirectional(GRU(c_lstm_dim, return_sequences=True, dropout_W=dropout))(char_emb_out)

    # ====================================== #
    # 3. Merge the Word RNN and the Char RNN #
    # ====================================== #

    # If we want to use character embeddings then extract the character representations from the
    # character RNN
    if char_emb:
        # Create an input for the matrix of word end indices
        inds = Input(shape=(None,), dtype='int64', name='token_end_idx')
        
        # Create an input for the matrix of segmentation masks
        seg_mask = Input(shape=(None,), dtype='int64', name='char_seg_mask')

        # Slice the character model out
#         char_model_slice = GatherLayer()([temp_out, inds])
        
        # Segment the character model
        char_model_seg = SegmentLayer(seg_func_name='mean')([char_out, inds, seg_mask])

        # Concatenate the outputs of the word model and the sliced character model
        merge_out = merge([word_out, char_model_seg], mode='concat', concat_axis=2)
    else:
        merge_out = word_out

    # Add Bidirectional lstm here
    gru_out = Bidirectional(GRU(w_emb_dim, return_sequences=True, dropout_W=dropout))(merge_out)
    
    # ================== #
    # 4. Compute Output  #
    # ================== #

    # Time distribute a final layers with binary output
    hout = TimeDistributed(Dense(w_lstm_dim, activation='tanh'))(gru_out)
    fout = TimeDistributed(Dense(1, activation='sigmoid'))(hout)
    
    model = Model([word_model_in, char_model_in, inds, seg_mask], output=[fout])
    
    # ====================================== #
    # 5. Define Accuracy Metrics and Compile #
    # ====================================== #

    def accuracy(y_true, y_pred):
        den = K.sum(K.cast(K.not_equal(y_true, -1), dtype='float32'))
        num = K.sum(K.cast(K.equal(y_true, K.round(y_pred)) & K.not_equal(y_true, -1), dtype='float32'))
        return num / den
    
    def true_pos(y_true, y_pred):
        den = tf.reduce_sum(tf.cast(tf.equal(tf.round(y_pred), 1) & tf.not_equal(y_true, -1), dtype='float32'))
        
        i = tf.equal(y_true, tf.round(y_pred)) & \
                tf.equal(1., tf.round(y_pred)) & \
                tf.not_equal(y_true, -1)
                
        num = tf.reduce_sum(tf.cast(i, dtype='float32'))

        frac = tf.select(den==0., 0., num / den)
        
        return frac
    
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[accuracy, true_pos])
    
    return model

## b. Construct a gensim model that we can use to access underlying embeddings from the massive GoogleNews embedding matrix
- 3 Million words in GoogleNews Vocab with 300 dimensional embeddings

In [74]:
# Load in the gensim word2vec model
w2vmodel = Word2Vec.load_word2vec_format('data/GoogleNews-vectors-negative300.bin', binary=True)

# II. Bring in Wikipedia Data to Train On
- Downloaded WikiNER data into data directory
- Download WikiGold data into data directory
- Will read in the whole WikiNER dataset into one pandas df
- Will read in pre-split train/test/dev data into pandas for WikiGold
- Preprocess the data to only tag PER (peoples names)
- Will batch sentences together into 'Documents' for WikiGold

### Utility Functions

In [5]:
import re

def untokenize(words):
    """
    Untokenizing a text undoes the tokenizing operation, restoring
    punctuation and spaces to the places that people expect them to be.
    Ideally, `untokenize(tokenize(text))` should be identical to `text`,
    except for line breaks.
    """
    text = ' '.join(words)
    step1 = text.replace("`` ", '"').replace(" ''", '"').replace('. . .',  '...')
    step2 = step1.replace(" ( ", " (").replace(" ) ", ") ")
    step3 = re.sub(r' ([.,:;?!%]+)([ \'"`])', r"\1\2", step2)
    step4 = re.sub(r' ([.,:;?!%]+)$', r"\1", step3)
    step5 = step4.replace(" '", "'").replace(" n't", "n't").replace(
         "can not", "cannot")
    step6 = step5.replace(" ` ", " '")
    return step6.strip()

def prepend_zeros(arr, num_zeros=1, dtype='int64'):
    '''
    Takes a 1-D numpy array and prepends the specified number of zeros
    '''
    zs = np.zeros(shape=(num_zeros,), dtype=dtype)
    return np.concatenate([zs, arr], axis=0)

def spans(txt, tokens):
    '''
    Takes the original (read: "untokenized" text) and the tokens and returns a list of word
    end indices.
    
    Parameters
    ==========
    txt : string
        untokenized / raw string we want to index the tokens into
    tokens : list, array
        list of tokens that make up the txt
        
    Returns
    =======
    word_inds : list
        list of word span indices for the tokens into the txt
    '''
    word_inds = []
    offset = 0
    for token in tokens:
        offset = txt.find(token, offset)
        word_inds.append((offset, offset+len(token)))
        offset += len(token)
    return word_inds

def segmentation_mask(word_end_inds):
    '''
    Takes end (inclusive) index of the tokens and returns a mask with length num of characters
    that masks each token with a monotonically increasing index
    
    Example
    -------
    ['H','e','l','l','o',' ','W','o','r','l','d'] -> [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    '''
    doc_len = max(word_end_inds) + 1
    
    word_inds = [-1] + word_end_inds
    
    seg_mask = np.zeros(doc_len, dtype='int64')
    
    for i in range(len(word_inds) - 1):
        s = word_inds[i]
        e = word_inds[i+1]
        seg_mask[s+1:e+1] = i
    return seg_mask 

## a. Methods for Reading in Data

In [6]:
# ==== #
# Vars #
# ==== #

# Paths
DATA_DIR = './data'
WIKIGOLD_DATA = os.path.join(DATA_DIR, 'WikiGold')
WIKINER_DATA = os.path.join(DATA_DIR, 'raw', 'aij-wikiner-en-wp3')

# Misc.
DOCSTART_TAG = '-DOCSTART-'

# ======================================= #
# Methods for Reading and Processing Data #
# ======================================= #

# Read in the data only (no processing except putting in in pandas)
def read_wikigold_datasets(data_dir):
    '''
    Reads in the wikipedia Gold datasets from the data directory.
    These datasets were extracted using scripts in
    https://github.palantir.build/DeltaSierra/tagtrex
    
    Paramters
    =========
    data_dir : str
        path to the data directory
        
    Returns
    =======
    train_df : pandas.core.frame.DataFrame
        dataframe with training data
    test_df : pandas.core.frame.DataFrame
        dataframe with test data
    dev_df : pandas.core.frame.DataFrame
        dataframe with dev data
    TODO: (areiner) what the hell is dev data?... I didn't do these splits
    '''
    
    dataset_names = ('train', 'test', 'dev')
    datasets = []  # ordered train, test, dev
    for dname in dataset_names: 
        found_datasets = glob.glob(data_dir + '/*' + dname + '*.pkl')
        if len(found_datasets)==0:
            print "No dataset with name {} found".format(dname)
        elif len(found_datasets) > 1:
            print "Multiple dataset with name {} found".format(dname)
        else:
            df = pd.read_pickle(found_datasets[0])
            datasets.append(df)
    return datasets

def read_wikiner_dataset(data_path):
    '''
    Reads in wikiner data in its raw text form and outputs a dataframe
    '''
    # Read in each line and process it
    with open(data_path, 'r') as f:
        
        # Track the list of tokens, tags
        token_lists = []
        tags_lists = []
        
        for line in f:
            # Current sample tokens / tags
            token_list = []
            tag_list = []
            
            # Processing
            l = line.strip()
            aug_tokens = l.split(' ')
            
            # Discard the POS
            for at in aug_tokens:
                at_split = at.split('|')
                token_list.append(at_split[0])
                tag_list.append(at_split[-1])
            
            # Add the sample
            token_lists.append(token_list)
            tags_lists.append(tag_list)
            
    df = pd.DataFrame({'document': token_lists[1:], 'tags': tags_lists[1:]})
    return df
    
# Method to filter all tags such that it is only 'O' or 'PER'
def modify_tags(df):
    '''
    Takes in a dataframe with columns "sentence" and "tags" and modifies
    the "tags" column to convert everything that is not "B-PER" to "O" and
    converts "B-PER" to "PER"
    
    Parameters
    ==========
    df : pandas.core.frame.DataFrame
        input dataframe of sentences and tags (with columns "sentence" and "tags")
    '''
    def modify_taglist(taglist):
        '''
        Function to apply to each array of tags in each cell of the "tags"
        column in the dataframe
        '''
        m = {'B-PER': 'PER', 'I-PER': 'PER'}
        return map(lambda v: m.get(v, 'O'), taglist)

    df['tags'] = df['tags'].apply(modify_taglist)
    
    return df

# Method to find the index for document separations
def calc_docstart_inds(df):
    '''
    Takes in a dataframe and looks for specific document start tags
    '''
    return df.index[df.sentence.map(lambda x: DOCSTART_TAG in x)].get_values()

# Method to concatenate sentence tags to document tags
def construct_doc_tag(dfs, max_sent_per_doc=None):
    '''
    Takes a DataFrame slice and returns a dataframe of tokens and tags that
    concatenates all sentence tokens and tags for the whole dataframe, potentially
    into groups of size specified by max_sent_per_doc size
    '''
    new_df = pd.DataFrame(columns=['document', 'tags'])
    
    sentences = []
    tags = []
    
    for i, (_, row) in enumerate(dfs.iterrows()):
        
        if (max_sent_per_doc is not None) and i!=0 and (i % max_sent_per_doc == 0):
            new_df = new_df.append({'document': sentences, 'tags': tags}, ignore_index=True)
            sentences = []
            tags = []

        tokens = row.sentence
        tagseq = row.tags
        
        sentences.extend(tokens)
        tags.extend(tagseq)
    
    new_df = new_df.append({'document': sentences, 'tags': tags}, ignore_index=True)
    
    return new_df

# Method to cluster sentence token dataframes into documents
def transform_sent_to_docs(df, max_sent_per_doc=None):
    '''
    Transform a DataFrame of lists of tokens and tags per sentence into
    a DataFrame of "Documents" and the tags for that document.
    At a minimum, we split by document length as per the wikipedia page
    '''
    doc_df = pd.DataFrame(columns=['document', 'tags'])
    
    # 1. Calculate the indices of Document starts
    doc_starts = prepend_zeros(calc_docstart_inds(df))
    
    # 2. Slice the df for the sentences in each document
    for i, (start, end) in enumerate(zip(doc_starts[:-1], doc_starts[1:])):
        if i != 0:
            start += 1
        dfs = df.iloc[start:end]
        inc_doc_df = construct_doc_tag(dfs, max_sent_per_doc=max_sent_per_doc)
        doc_df = doc_df.append(inc_doc_df)
    
    doc_df.reset_index(inplace=True, drop=True)
    
    return doc_df
        
# Method that adds a column which contains the untokenized sentence / document
def untokenize_column(df, col='document', new_col='document_string'):
    '''
    Untokenizes a column of lists of tokens. The intended use in the pipeline is to
    construct untokenized strings after clustering sentences into documents or partial
    documents.
    '''
    df[new_col] = df[col].map(untokenize)
    return df

# Method that adds a column which contains the index of the tokens into the untokenized sentence / document
def word_index_columns(df, doc_col='document_string', tok_col='document'):
    '''
    Takes a column of strings (either sentences or documents as long as it's one continuous string) 
    and creats two new columns a 'word_start_inds' column that contains the start indices of all tokens
    and a 'word_end_inds' that contains the end indices of all tokens
    '''
    # Calculate the word spanning indices
    word_inds = df.apply(lambda r: spans(r[doc_col], r[tok_col]), axis=1)
    word_start_inds = word_inds.apply(lambda v: [e[0] for e in v])
    word_end_inds = word_inds.apply(lambda v: [e[1] - 1 for e in v])
    word_segment_mask = word_end_inds.map(segmentation_mask)
    
    # Insert new columns
    df['word_start_inds'] = word_start_inds
    df['word_end_inds'] = word_end_inds
    df['word_segment_mask'] = word_segment_mask
    return df

## b. Methods for data preprocessing

In [7]:
# Construct word and character maps
def construct_map(element_lists, vocab_size=None):
    '''
    Constructs a vocabulary from 
    '''
    from collections import Counter
    c = Counter()
    for els in element_lists:
        c.update(els)

    if vocab_size is not None:
        most_common = [x[0] for x in c.most_common(vocab_size)]
        hash_map = dict(zip(most_common, range(1, len(most_common)+1)))
    else:
        hash_map = dict(zip(c.keys(), range(1, len(c)+1)))

    return hash_map

def reverse_map(m):
    m_inv = dict(((ind, k) for k, v in m.iteritems()))
    return m_inv
    
def character_column(df):
    '''
    Inserts a column of the individual characters into the dataframe 
    
    Parameters
    ==========
    df : pandas.core.frame.DataFrame
        dataframe containing the string / untokenized documents
        
    Returns
    =======
    new_df : pandas.core.frame.DataFrame
        Updated df
    '''
    df['chars'] = df['document_string'].map(lambda v: list(v))
    return df

def construct_word_char_maps(df, vocab_size=None, return_inv_dicts=False):
    '''
    Construct the word and character maps from the dataframe
    
    Parameters
    ==========
    df : pandas DataFrame
        dataframe with tokens and character lists already constructed
    
    Returns
    =======
    token_map : dict
    char_map : dict
    '''
    token_map = construct_map(df['document'])
    char_map = construct_map(df['chars'])
    
    if return_inv_dicts:
        token_map_inv = reverse_map(token_map)
        char_map_inv = reverse_map(char_map)
        
        return token_map, char_map, token_map_inv, char_map_inv
    else:
        return token_map, char_map

def encode_strings(element_lists, hash_map):
    '''
    Encode the element_list in terms of integers
    NOTE: 0 is reserved for masking
    '''
    new_element_list = []
    for els in element_lists:
        new_els = map(lambda x: hash_map.get(x, len(hash_map)+1), els)
        new_element_list.append(new_els)
    return new_element_list
    
def encode_tokens_chars(df, token_map, char_map):
    '''
    Encode the tokens and characters
    '''
    df['token_enc'] = encode_strings(df['document'], token_map)
    df['char_enc'] = encode_strings(df['chars'], char_map)
    return df

def _encode_tags(tag_list):
    '''
    Encode a taglist
    '''
    def _tag_str_to_int(tag_str):
        if tag_str == 'PER':
            return 1
        else:
            return 0
        
    tags_enc = map(_tag_str_to_int, tag_list)
    return tags_enc
    
def encode_tags(df):
    '''
    Encode the tags as binary outcomes
    
    Parameters
    =========
    df : pandas.core.frame.DataFrame
        dataframe with character tags that are in the set {'O', and 'PER'}
    '''
    
    df['tags_enc'] = df['tags'].apply(_encode_tags)
    return df

def remove_pure_neg(df, frac=0.1):
    '''
    Data processing function that removes a fraction of the samples that contain no positive tags
    '''
    
    if frac == 0.:
        return df
    
    # Extract the samples with positive tags and a fraction of those with no positive tags
    nonz_idx = df.index[df.tags_enc.map(sum)!=0]
    z_idx = df.index[df.tags_enc.map(sum)==0]
    z_idx_red = z_idx[:len(z_idx)*(1-frac)]
    new_idx = z_idx_red.append(nonz_idx)
    
    # Select the new index and shuffle the samples
    df_new = df.loc[new_idx].sample(frac=1.).reset_index()
    
    return df_new
    
def _construct_character_tags(word_start_inds, word_end_inds, char_list, tag_enc):
    '''
    
    '''
    char_tags = np.zeros(len(char_list), dtype='int64')
    
    word_start_inds = np.array(word_start_inds)
    word_end_inds = np.array(word_end_inds)
    tag_enc = np.array(tag_enc)
    
    word_start_ind_tag = word_start_inds[tag_enc==1]
    word_end_ind_tag = word_end_inds[tag_enc==1]
    word_span_tag = zip(word_start_ind_tag, word_end_ind_tag)
    
    for (s, e) in word_span_tag:
        char_tags[s:e+1] = 1
    
    return list(char_tags)

def construct_character_tags(df):
    df['char_tag_mask'] = df.apply(lambda r: _construct_character_tags(r.word_start_inds, r.word_end_inds, 
                                                                       r.chars, r.tags_enc), axis=1)
    return df

## c. Full data processing pipeline function

In [8]:
def process_wikigold_dataset(df, max_sent_per_doc=None, rem_neg_frac=0.0):
    
    # Text Processing
    df = modify_tags(df)
    df = transform_sent_to_docs(df, max_sent_per_doc=max_sent_per_doc) # Concatenate sentence tokens into topics
    df = untokenize_column(df)  # (approximately) concatenate the tokens into documents
    df = word_index_columns(df)  # index the tokens into the untokenized strings
    df = character_column(df)

    # Encoding work
    tm, cm = construct_word_char_maps(df)  # first construct the hash maps
    df = encode_tokens_chars(df, tm, cm)  # encode the tokens and the characters
    df = encode_tags(df)  # encode the tags as binary outcomes 'O' -> 0 and 'PER' -> 1
    df = remove_pure_neg(df, frac=rem_neg_frac)  # remove a fraction of negative samples
    
    # Extra info for model validation
    df = construct_character_tags(df)
    
    return df, tm, cm

def run_wikigold_data_pipeline(data_path, max_sent_per_doc=None, rem_neg_frac=0.0):
    '''
    Runs the entire transformation pipeline for WikiGold data.
    
    Pipeline Steps
    --------------
    1. Read in wiki datasets (train, test, dev)
    2. Modify the IOB tags to only 'O' and 'PER'
    For each dataset in the wiki datasets:
        3. Take sentence examples and concatenate them into individual documents
        4. Take the documents which consist of lists of tokens and "untokenize" them into continuous strings
        5. Take the tokens and index them into the untokenized text so we have start and end indices for each token
        6. Split up the untokenized text into a list of characters
        7. Construct element -> monotonically increasing index map for both tokens and characters
        8. Encode the tokens and characters by their index in their respective maps
        9. Encode the tags as binary outcomes 'O' -> 0 and 'PER' -> 1
    '''
    # 1. Reading in data
    wiki_datasets = read_wikigold_datasets(data_path)

    wiki_datasets_processed = []
    wiki_datasets_maps = []
    # 2. Run processing on each dataframe
    for df in wiki_datasets:
        new_df, token_map, char_map = process_wikigold_dataset(
            df,
            max_sent_per_doc=max_sent_per_doc,
            rem_neg_frac=rem_neg_frac
        )
        wiki_datasets_processed.append(new_df)
        wiki_datasets_maps.append((token_map, char_map))
    
    return wiki_datasets_processed, wiki_datasets_maps

def process_wikiner_dataset(df, rem_neg_frac=0.0):
    # Text Processing
    df = modify_tags(df)
    df = untokenize_column(df)  # (approximately) concatenate the tokens into documents
    df = word_index_columns(df)  # index the tokens into the untokenized strings
    df = character_column(df)

    # Encoding work
    tm, cm = construct_word_char_maps(df)  # first construct the hash maps
    df = encode_tokens_chars(df, tm, cm)  # encode the tokens and the characters
    df = encode_tags(df)  # encode the tags as binary outcomes 'O' -> 0 and 'PER' -> 1
    df = remove_pure_neg(df, frac=rem_neg_frac)  # remove a fraction of negative tags
    
    # Extra info for model validation
    df = construct_character_tags(df)
    
    return df, tm, cm

def run_wikiner_data_pipeline(data_path, rem_neg_frac=0.0):
    df = read_wikiner_dataset(data_path)
    df, tm, cm = process_wikiner_dataset(df, rem_neg_frac=rem_neg_frac)
    return df, tm, cm

In [9]:
# Load WikiGold datasets
(train, test, dev), ((tm_train, cm_train), (tm_test, cm_test), (tm_dev, cm_dev)) = \
                                run_wikigold_data_pipeline(WIKIGOLD_DATA,
                                                           max_sent_per_doc=5,
                                                           rem_neg_frac=0.0)

In [265]:
# Load WikiNER datasets
wikiner, tm_ner, cm_ner = run_wikiner_data_pipeline(WIKINER_DATA, rem_neg_frac=0.7)

## d. Final padding

In [10]:
def pad_encoding_column(col, value=0):
    '''
    Takes a pandas series of lists of encodings and returns a single matrix with padded results
    '''    
    col_padded = sequence.pad_sequences(col.tolist(), padding='post', value=value)
    return col_padded

def generate_padded_data(df):
    '''
    Take the dataframe and return numpy matrices with padded encodings
    '''
    cols_to_pad = ['token_enc', 'char_enc', 'tags_enc', 
                   'word_start_inds', 'word_end_inds', 'word_segment_mask']
    padded_data = []
    max_doc_len = max(df.document.map(len))

    for col in cols_to_pad:
        if ('tag' in col) or ('inds' in col):
            col_pad = pad_encoding_column(df[col], value=-1)
        elif 'segment' in col:
            col_pad = segment_mask_padding2(df[col])
        else:
            col_pad = pad_encoding_column(df[col])
        padded_data.append(col_pad)
    return padded_data

def segment_mask_padding(seg_col, max_doc_len):
    col_pad = pad_encoding_column(seg_col, value=-1)
    
    max_char_len = col_pad.shape[1]
    
    for i, r in enumerate(col_pad):
        last_word_idx = max(r)
        idx_last_word = np.argmax(r)
        mask_inds = (r == -1)
        
        # from idx of last mask + 1 put enough unique values so there are as many unique values as 
        # the maximum number of unique tokens
        total_pad = np.zeros(mask_inds.sum(), dtype='int64')
        fake_unique = np.arange(last_word_idx + 1, max_doc_len)
        total_pad[:len(fake_unique)] = fake_unique
        
        # then fill in the remaing open slots with the last value
        total_pad[len(fake_unique):] = fake_unique[:][-1:]
        r[mask_inds] = total_pad
    return col_pad

def segment_mask_padding2(seg_col):
    col_pad = pad_encoding_column(seg_col, value=-1)
    
    for r in col_pad:
        last_word_idx = max(r)
        mask_inds = (r == -1)
        
        r[mask_inds] = last_word_idx + 1
    return col_pad

def unk_insertion(token_enc, tags_enc, unk_idx, unk_neg_prob=0.02, unk_pos_prob=0.1):
    '''
    Imparts unknown tags into the dataset
    '''
    
    # If the probablities are set to None then there is no unknown insertion
    if (unk_neg_prob is None) and (unk_pos_prob is None):
        return token_enc
    
    # Get a selection mask for the negative and positive tags
    tags_neg = tags_enc == 0
    tags_pos = tags_enc == 1
    unk_neg_replace_mask = (np.random.rand(token_enc.shape[0], token_enc.shape[1]) < unk_neg_prob) & tags_neg
    unk_pos_replace_mask = (np.random.rand(token_enc.shape[0], token_enc.shape[1]) < unk_pos_prob) & tags_pos
    
    token_enc_unk = token_enc[:]
    token_enc_unk[unk_neg_replace_mask] = unk_idx
    token_enc_unk[unk_pos_replace_mask] = unk_idx
    
    return token_enc_unk

def data_generator(df, batch_size=20, nb_epoch=10, shuffle=True, unk_neg_prob=None, unk_pos_prob=None):
    '''
    A python generator that yields shuffled batches
    '''
    # Generate the full padded numpy arrays from the dataframe
    token_enc, char_enc, tags_enc, word_start_ind, word_end_ind, word_segment_mask = generate_padded_data(df)
    
    for _ in range(nb_epoch):
        
        # Apply random unk insertion
        if (unk_neg_prob is not None) or (unk_pos_prob is not None):
            unk_idx = df.token_enc.map(max).max() + 1
            token_enc_unk = unk_insertion(token_enc, tags_enc, unk_idx,
                                          unk_neg_prob=unk_neg_prob, unk_pos_prob=unk_pos_prob)
            vars = [token_enc_unk, char_enc, tags_enc, word_start_ind, word_end_ind, word_segment_mask]
        else:
            vars = [token_enc, char_enc, tags_enc, word_start_ind, word_end_ind, word_segment_mask]

        # Shuffle the index
        idx = np.arange(token_enc.shape[0])
        if shuffle:
            np.random.shuffle(idx)

        # Determine number of batches in the epoch
        num_batches = (token_enc.shape[0] / batch_size) + 1

        for i in range(num_batches):
            idx_batch = idx[i*batch_size:(i+1)*batch_size]
            token_enc_batch, char_enc_batch, \
            tags_enc_batch, word_start_ind_batch, \
            word_end_ind_batch, word_segment_mask_batch = map(lambda v: v[idx_batch], vars)

            inp = [token_enc_batch, char_enc_batch, word_end_ind_batch, word_segment_mask_batch]
            target = np.expand_dims(tags_enc_batch, 2)
            yield (inp, target)

# III. Construct Model and Bring in Pre-trained word emebddings
- For now, we will use GoogleNews embedding vectors with dimensionality 300
- We will not pre-train character emebeddings for now
    - Why? bc I haven't found character embedings yet and I would imagine the best representations can different markedly from use case to use case

In [14]:
# Training vocab sizes
word_vocab_size = len(tm_ner)
char_vocab_size = len(cm_ner)

NameError: name 'tm_ner' is not defined

In [11]:
# Training vocab sizes
word_vocab_size = len(tm_train)
char_vocab_size = len(cm_train)

### Local Model

In [20]:
# Local test model construction
test_model = construct_model(word_vocab_size=word_vocab_size, char_vocab_size=char_vocab_size,
                        w_emb_dim=5, w_lstm_dim=16, c_emb_dim=5, c_lstm_dim=8, trainable_embedding=False)



### Real Model (Cloud)

In [78]:
# Read model construction
model = construct_model(word_vocab_size=word_vocab_size, char_vocab_size=char_vocab_size,
                        w_emb_dim=300, w_lstm_dim=100, c_emb_dim=50, c_lstm_dim=100, 
                        embedding=w2vmodel, token2idx=tm_train, trainable_embedding=False)

# VI. Train the Model!

In [21]:
# Generate the data
token_enc_train, char_enc_train, tags_enc_train, word_start_ind_train, word_end_ind_train, word_segment_mask_train = \
                                                                                        generate_padded_data(train)

### Train testing model locally

In [22]:
test_model.fit([token_enc_train, char_enc_train, word_end_ind_train, word_segment_mask_train], 
          np.expand_dims(tags_enc_train, 2), batch_size=2, callbacks=[ModelCheckpointBatch(MODEL_SAVE_PATH, period=2)])

Epoch 1/10
 18/288 [>.............................] - ETA: 805s - loss: 0.6320 - accuracy: 0.9074 - true_pos: nan

KeyboardInterrupt: 

In [26]:
test_model.fit_generator(data_generator(train, batch_size=2, nb_epoch=2, unk_neg_prob=0.05, unk_pos_prob=0.1), 
                         samples_per_epoch=train.shape[0], 
                         nb_epoch=2, callbacks=[ModelCheckpointBatch(MODEL_SAVE_PATH, period=2)])

Epoch 1/2
 24/288 [=>............................] - ETA: 724s - loss: 0.4661 - accuracy: 0.9617 - true_pos: nan

KeyboardInterrupt: 

### Train full model in cloud

In [63]:
model.fit([token_enc_train, char_enc_train, word_end_ind_train, word_segment_mask_train], 
          np.expand_dims(tags_enc_train, 2), batch_size=100)

Epoch 1/10
Epoch 2/10

KeyboardInterrupt: 

In [23]:
model.fit_generator(data_generator(wikiner, batch_size=100, nb_epoch=10, unk_neg_prob=0.05, unk_pos_prob=0.1), 
                    samples_per_epoch=wikiner.shape[0], 
                    nb_epoch=10, callbacks=[ModelCheckpointBatch(MODEL_SAVE_PATH, period=100)])

NameError: name 'model' is not defined

# VII. Model Validation

In [15]:
np.set_printoptions(precision=4, suppress=True)

## a. Methods for validation

**Masks**
- Tokens -> 0
- Characters -> 0
- Tags -> -1
- Word Indices -> -1

In [69]:
def predict_char_tags(word_start_inds, word_end_inds, predicted_tag_enc):
    '''
    Returns lists the same length as the number of characters in the document string
    that encodes characters as 1 when they belong to a tagged entity
    
    Paramters
    =========
    word_start_inds : numpy array
        The padded word start indices for a single example
        
    word_end_inds : numpy array
        The padded word end indices for a single example
    
    predicted_tag_enc : numpy array
        The predicted token-wise tag encoding for a single example with padded (nonsense) predictions left in
        Note: this is the form of the output from the model
    '''
    # 1. Get the length of the document string (num of characters in example) from the word indices
    filt = word_end_inds!=-1
    word_start_inds = word_start_inds[filt]
    word_end_inds = word_end_inds[filt]
    doc_string_len = word_end_inds.max() + 1
    
    # 2. Remove the padded tags
    predicted_tag_enc = predicted_tag_enc[filt]

    # 3. Fill in a numpy array of tags the length of the unpadded document string
    char_tags = np.zeros(doc_string_len, dtype='int64')
    
    word_start_ind_tag = word_start_inds[predicted_tag_enc==1]
    word_end_ind_tag = word_end_inds[predicted_tag_enc==1]
    word_span_tag = zip(word_start_ind_tag, word_end_ind_tag)
    
    for (s, e) in word_span_tag:
        char_tags[s:e+1] = 1
    
    return list(char_tags)
    
def tagged_entities_in_docstring(char_lists,
                                 word_start_inds_arr, word_end_inds_arr,
                                 yhats,
                                 true_char_tag_enc_lists,
                                 thresh=0.5):
    '''
    Takes the model output and returns:
    (1) the array of characters
    (2) the characters that actually belong to tags
    (3) the characters we estimate belong to tags
    
    Parameters
    ==========
    char_lists : nested list
        nested list of character lists for the examples passed to produce the provided
        yhat model output
        
    word_star_inds_arr : numpy array
        padded array of word start indices
    
    word_end_inds_arr : numpy array
        padded array of word end indices
        
    yhats : numpy array
        output of the model
        
    true_char_tag_enc_lists : nested list
        true char by char tag encoding of the model
        
    Returns
    =======
    res : list of 3-tuple
        each tuple corresponds to a single training example / document / partial document
        and is composed of
        (i) sequence of characters that make up the document 
        (ii) sequence of predicted tags for each character representing a mask of what tokens were tagged
        (iii) sequence of true tags for each character representing a mask of that tokens were tagged
    
    NOTE
    ----
    All the inputs should share the same size in the 0th dimension.
    For dim 1, the sizes should be:
        (len of doc i in characters, 
        max document len in tokens,
        max document len in tokens,
        max document len in tokens, 
        len of doc i in characters)
    '''
    res = []
    for i in range(len(word_start_inds_arr)):
        
        # 1. Get each example passed
        
        # Get the predictions
        char_list = char_lists[i]
        word_start_inds = word_start_inds_arr[i]
        word_end_inds = word_end_inds_arr[i]
        yhat = yhats[i]
        predicted_tag_enc =(yhat >= thresh).astype('int64')
        
        # Extract the true character tags
        true_char_tag_enc = true_char_tag_enc_lists[i]
        
        # Calculate the predicted char_tags
        predicted_char_tags = predict_char_tags(word_start_inds, word_end_inds, predicted_tag_enc)
        
        # Append the example results
        res.append((np.array(char_list), np.array(predicted_char_tags), np.array(true_char_tag_enc)))
    
    return res

def add_tagged_char_mask_to_df(model, df, batch_size=5, thresh=0.5):
    
    token_enc, char_enc, tags_enc, word_start_ind, word_end_ind, word_segment_mask = generate_padded_data(df)
    x = [token_enc, char_enc, word_end_ind, word_segment_mask]
    if 'prob' in df.columns:
        y_preds = df.prob.tolist()
    else:
        y_preds = model.predict(x, batch_size=batch_size)[..., 0]
        
    char_lists = df.chars.tolist()
    true_char_tag_enc_lists = df.char_tag_mask.tolist()
    
    res_list = tagged_entities_in_docstring(char_lists,
                                                 word_start_ind, word_end_ind, 
                                                 y_preds,
                                                 true_char_tag_enc_lists, 
                                                 thresh=thresh)
    pred_char_tag = [r[1] for r in res_list]
    df['pred_char_tag_mask'] = pred_char_tag
    return df

def add_predictions_to_df(x, model, df, batch_size=5, thresh=0.5):
    if 'prob' in df.columns:
        y_preds = df.prob.tolist()
    else:
        y_preds = model.predict(x, batch_size=batch_size)[..., 0]
    y_preds_nopad = []
    yhats = []
    for i in range(len(y_preds)):
        token_mask = x[0][i] != 0
        y_pred_nopad = y_preds[i][token_mask]
        y_preds_nopad.append(y_pred_nopad)
        yhats.append((y_pred_nopad > thresh).astype('int32'))
    df['prob'] = y_preds_nopad
    df['yhat'] = yhats
    return df

def calculate_and_add_preds_to_df(model, df, batch_size=5, thresh=0.5):
    '''
    Calculate and add preds
    '''
    token_enc, char_enc, tags_enc, word_start_ind, word_end_ind, word_segment_mask = generate_padded_data(df)
    x = [token_enc, char_enc, word_end_ind, word_segment_mask]
    df = add_predictions_to_df(x, model, df, batch_size=batch_size)
    return df

## b. Run Predictions on the train and test dataframes and store them

In [71]:
train_p = calculate_and_add_preds_to_df(model, train)

KeyboardInterrupt: 

In [70]:
train_p = add_tagged_char_mask_to_df(model, train_p)

NameError: name 'train_p' is not defined

In [47]:
train_p.columns

Index([u'document', u'tags', u'document_string', u'word_start_inds',
       u'word_end_inds', u'chars', u'token_enc', u'char_enc', u'tags_enc',
       u'char_tag_mask', u'prob', u'yhat', u'pred_char_tag_mask'],
      dtype='object')

### Explore Train Example

In [153]:
i = 0
x_in = [token_enc_train[i:i+1], char_enc_train[i:i+1], word_end_ind_train[i:i+1], word_segment_mask_train[i:i+1]]

out = model.predict(x_in)[0, :, 0]

token_ex = np.array(train.document[i])
tags_ex = np.array(train.tags[i])
tags_enc_ex = np.array(train.tags_enc[i])

m = token_ex[(tags_enc_ex.nonzero()[0])]

print m
print 'Real tags:', zip(m, out[tags_enc_ex.nonzero()[0]])
out_sorted = out.argsort()
out_sorted = out_sorted[out_sorted<len(token_ex)]
print "Lowest scores:", token_ex[out_sorted[:30]]
print "Highest scores:", token_ex[out_sorted[-20:]]
print "Highest scores:", out[out_sorted[-20:]]

NameError: name 'model' is not defined

In [76]:
# Test data
token_enc_test, char_enc_test, tags_enc_test, word_start_ind_test, word_end_ind_test = \
                                                                                        generate_padded_data(test)

### Explore Test Example

In [84]:
i = 2
x_in = [token_enc_test[i:i+1], char_enc_test[i:i+1], word_end_ind_test[i:i+1]]

out = model.predict(x_in)[0, :, 0]

token_ex = np.array(test.document[i])
tags_ex = np.array(test.tags[i])
tags_enc_ex = np.array(test.tags_enc[i])

m = token_ex[(tags_enc_ex.nonzero()[0])]

print m
print zip(m, out[tags_enc_ex.nonzero()[0]])
out_sorted = out.argsort()
out_sorted = out_sorted[out_sorted<len(token_ex)]
print "Lowest scores:", token_ex[out_sorted[:30]]
print "Highest scores:", token_ex[out_sorted[-20:]]
print "Highest scores:", out[out_sorted[-20:]]

['Dorothea' 'Dorothea' 'von' 'Schlegel' 'Rahel' 'Levin' 'Henriette' 'Herz'
 'Madame' 'de' 'Sta\xc3\xab' 'Friedrich' 'Philipp' 'Moses' 'Mendelssohn'
 'Immanuel' 'Kant' 'John' 'Locke' 'Alexander' 'Pope' 'Dorothea']
[('Dorothea', 0.00082105485), ('Dorothea', 0.0094426926), ('von', 0.0055673928), ('Schlegel', 0.0051288288), ('Rahel', 0.00068829377), ('Levin', 9.7050888e-06), ('Henriette', 0.00013976262), ('Herz', 2.7702123e-05), ('Madame', 0.010900004), ('de', 0.0044732802), ('Sta\xc3\xab', 0.0014611182), ('Friedrich', 0.52245504), ('Philipp', 0.00066411082), ('Moses', 0.00057833287), ('Mendelssohn', 0.0017986018), ('Immanuel', 0.00022611055), ('Kant', 9.4189062e-07), ('John', 7.342115e-05), ('Locke', 7.0163347e-05), ('Alexander', 0.0013022374), ('Pope', 0.051861312), ('Dorothea', 0.01213771)]
Lowest scores: ['Kant' 'and' 'and' 'adopted' 'and' 'greatest' 'and' 'critics' 'of'
 'translator' 'Levin' ',' 'novelists' 'to' 'musicians' ',' 'convert'
 'leading' ',' 'Herz' ',' ',' 'as' ',' 'medieva

# Old Stuff

## Q: How can we determine the end of word indices from the tokens?
- we need to use the provided tokens and reconstruct the original string
    - We do this via an "untokenize" function found off the shelf online
    - If we tokenize again (using nltk word_tokenize) do we get the same result?   

In [312]:
print word_tokenize(res.iloc[4].document_string)[257]
print res.iloc[4].document[257]
print untokenize(res.iloc[4].document[254:260])

def untokenize2(tokens):
    import string
    return "".join([" "+i if not i.startswith("'") and i not in string.punctuation else i for i in tokens]).strip()
untokenize2(res.iloc[4].document[254:270])

``
"
bonus tracks: " Battle of


'bonus tracks:" Battle of One"( an original song that was also set'

## Q: Why is the true positive metric for the keras model messing up and returning nan?
- Extracting the variables from the tensorflow model and constructing the accuracy metric outside works...

In [136]:
y_true = tf.placeholder(dtype='float32')
y_pred = model.output

def true_pos(y_true, y_pred):
    den = tf.reduce_sum(tf.cast(tf.equal(tf.round(y_pred), 1) & tf.not_equal(y_true, -1), dtype='float32'))

    i = tf.equal(y_true, tf.round(y_pred)) & \
            tf.equal(1., tf.round(y_pred)) & \
            tf.not_equal(y_true, -1)

    num = tf.reduce_sum(tf.cast(i, dtype='float32'))

    frac = tf.select(den==0., 0., num / den)

    return frac

frac = true_pos(y_true, y_pred)

In [74]:
# Generate the data
token_enc_train, char_enc_train, tags_enc_train, word_start_ind_train, word_end_ind_train = \
                                                                                        generate_padded_data(train)

In [75]:
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

In [87]:
fd = dict(zip(model.input, [token_enc_train[:2,:,],
                           char_enc_train[:2, :],
                           word_end_ind_train[:2, :]]))

In [132]:
for i in range(100):
    f, n, d = sess.run([frac, num, den], feed_dict=dict(zip(model.input + [y_true], [token_enc_train[i:i+1], 
                                                                     char_enc_train[i:i+1],
                                                                     word_end_ind_train[i:i+1],
                                                                     np.expand_dims(tags_enc_train[i:i+1], 2)]
                                         )
                                     )
            )
    

    print f, n, d

0.0 0.0 26.0
0.0 0.0 3.0
0.025641 1.0 39.0
0.0 0.0 40.0
0.047619 1.0 21.0
0.0952381 2.0 21.0
0.0 0.0 44.0
0.0 0.0 42.0
0.0 0.0 43.0
0.0 0.0 35.0
0.037037 2.0 54.0
0.0714286 2.0 28.0
0.08 2.0 25.0
0.03125 1.0 32.0
0.0655738 4.0 61.0
0.0434783 1.0 23.0
0.0714286 2.0 28.0
0.0 0.0 27.0
0.171429 6.0 35.0
0.04 2.0 50.0
0.0 0.0 41.0
0.030303 1.0 33.0
0.0 0.0 48.0
0.0 0.0 46.0
0.0 0.0 30.0
0.0 0.0 33.0
0.0 0.0 42.0
0.0 0.0 50.0
0.0526316 2.0 38.0
0.0 0.0 53.0
0.0689655 2.0 29.0
0.0 0.0 14.0
0.0 0.0 53.0
0.0 0.0 34.0
0.0 0.0 28.0
0.0 0.0 14.0
0.0 0.0 9.0
0.0 0.0 73.0
0.0 0.0 68.0
0.0454545 4.0 88.0
0.0833333 3.0 36.0
0.0714286 2.0 28.0
0.0 0.0 39.0
0.0 0.0 27.0
0.0 0.0 31.0
0.097561 4.0 41.0
0.215686 11.0 51.0


KeyboardInterrupt: 

## Old Segmentation Layer that didn't account for padding properly

In [4]:
class SegmentLayerOld(Layer):
    '''
    Takes a segmented sum
    '''
    def __init__(self, seg_func_name='sum', **kwargs):
        super(SegmentLayer, self).__init__(**kwargs)
        if seg_func_name == 'sum':
            self.seg_func = tf.segment_sum
        elif seg_func_name == 'mean':
            self.seg_func = tf.segment_mean
        elif seg_func_name == 'max':
            self.seg_func = tf.segment_max
        else:
            self.seg_func = tf.segment_sum
        
    def build(self, input_shape):
        super(SegmentLayer, self).build(input_shape)
    
    def compute_mask(self, x, mask=None):
        return K.cast(K.not_equal(x[1], -1), 'bool')
    
    def call(self, x, mask=None):
        rnn_inp = x[0]
        segment_mask = x[2]

        def f(inp):
            '''
            Performs a segmented sum on each input of (max_char_len, char_lstm_dim)
            '''
            mat = inp[0]
            seg_mask = inp[1]
            seg_sum = self.seg_func(mat, seg_mask)
            return seg_sum
        
        map_fn_out = tf.map_fn(f, elems=(rnn_inp, segment_mask), dtype='float32')
        
        return map_fn_out
    
    def get_output_shape_for(self, input_shape):
        rnn_shape = input_shape[0]
        ind_shape = input_shape[1]
        
        return (rnn_shape[0], ind_shape[1], rnn_shape[2])