In [2]:
import keras
from keras.models import Model
from keras.layers import Input, LSTM, Dense, Bidirectional, Embedding, Reshape, Dropout
from keras.preprocessing.sequence import pad_sequences
import numpy as np
from sklearn.model_selection import train_test_split




In [9]:
from tqdm import tqdm

def test_model(dtest, model, char2idx, maxlen):
    np.set_printoptions(precision=2, suppress=True)
    passed = 0
    failed = 0
    total_samasa = 0
    correct_samasa = 0
    inputs = []
    targets = []
    for data in dtest:
        target = np.array(list(data[1]))
        input_word = data[0]
    
        inputs.append(input_word)
        targets.append(target)
    
    X_test = [[char2idx[c] for c in w] for w in inputs]
    X_test = pad_sequences(maxlen=maxlen, sequences=X_test, padding="post", value=char2idx['*'])
    
    Y_test = targets
    Y_test = pad_sequences(maxlen=maxlen, sequences=Y_test, padding="post", value=0.0)
    Y_test = np.array(Y_test).reshape(-1, maxlen, 1)
   
    startlist = []
    for i in tqdm(range(X_test.shape[0])):
        test = X_test[i].reshape((-1, maxlen))
        res = model.predict(test, verbose=0)
        res = res.reshape((maxlen))
        dup = np.copy(res)
        act = Y_test[i].reshape((maxlen))

        wordlen = 0
        for j in range(maxlen):
            if X_test[i][j] == char2idx['*']:
                break
            else:
                wordlen = wordlen + 1

        res = res[0:wordlen]
        act = act[0:wordlen]
        origres = res
        
        for j in range(wordlen):
            if(res[j] >= 0.5):
                res[j] = 1
            else:
                res[j] = 0
                
        ires = res.astype(int)
        iact = act.astype(int)
        temp = np.multiply(ires, iact)
        total_samasa = total_samasa + np.sum(iact)
        correct_samasa = correct_samasa + np.sum(temp)

        comparison = ires == iact
        
        if comparison.all():
            passed = passed + 1
        else:
            failed = failed + 1

    print(passed)
    print(failed)
    print(passed*100/(passed+failed))
    print(correct_samasa)
    print(total_samasa)
    print(correct_samasa*100/total_samasa)

    return startlist

In [10]:
# Load model and test files
model = keras.models.load_model('stage1_bilstm.h5', compile=False)
fh = open('stage1_char2idx.txt', 'r')
data = fh.read()
char2idx = eval(data)
fh.close()
file = open("dtest.csv", "r")
lines = file.readlines()
file.close()
dtest = []
for line in lines:
    dtest.append(line.strip().split(','))
test_model(dtest, model, char2idx, 72)

100%|██████████| 17304/17304 [26:59<00:00, 10.69it/s] 

15036
2268
86.89320388349515
23717
25861
91.70952399365841





[]

In [9]:
from tqdm import tqdm

max_encoder_seq_length = 6
max_decoder_seq_length = 11
num_tokens = 52

def decode_sequence(input_seq, encoder_model, decoder_model, reverse_target_char_index):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq, verbose=0)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, token_index['&']] = 1.

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict([target_seq] + states_value, verbose=0)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '$' or
           len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        # Update states
        states_value = [h, c]

    return decoded_sentence

def infer_sandhi_split(dtest, encoder_model, decoder_model, token_index):
    input_texts = []
    target_texts = []

    for data in dtest:
        [input_text, target_text] = data.split(',')
        input_texts.append(input_text)
        target_texts.append(target_text)

    encoder_input_data = np.zeros((len(input_texts), max_encoder_seq_length, num_tokens), dtype='float32')
    
    for i, input_text in enumerate(input_texts):
        for t, char in enumerate(input_text):
            if char not in token_index:
                continue
            encoder_input_data[i, t, token_index[char]] = 1.
        encoder_input_data[i, t + 1:, token_index['*']] = 1.
    
    # Reverse-lookup token index to decode sequences back to something readable.
    reverse_input_char_index = dict((i, char) for char, i in token_index.items())
    reverse_target_char_index = dict((i, char) for char, i in token_index.items())
    
    total = len(encoder_input_data)
    passed = 0
    results = []
    for seq_index in tqdm(range(len(encoder_input_data))):
        # Take one sequence (part of the training set)
        # for trying out decoding.
        input_seq = encoder_input_data[seq_index: seq_index + 1]
        decoded_sentence = decode_sequence(input_seq, encoder_model, decoder_model, reverse_target_char_index)
        decoded_sentence = decoded_sentence.strip()
        decoded_sentence = decoded_sentence.strip('$')
        results.append(decoded_sentence)
        if decoded_sentence.strip() == target_texts[seq_index]:
            passed = passed + 1
        else:
            print(input_texts[seq_index]+" "+str(target_texts[seq_index])+" "+str(decoded_sentence))

    print("Passed: "+str(passed)+'/'+str(total)+', '+str(passed*100/total))


encoder_model = keras.models.load_model('stage2_encoder.h5')
decoder_model = keras.models.load_model('stage2_decoder.h5')
fh = open('stage2_token_index.txt', 'r')
data = fh.read()
token_index = eval(data)
fh.close()
file = open("stage2_dtest.csv", "r")
lines = file.readlines()
file.close()
dtest = []
for line in lines:
    dtest.append(line.strip())

infer_sandhi_split(dtest[0:100], encoder_model, decoder_model, token_index)

  4%|▍         | 4/100 [00:04<01:33,  1.03it/s]

pag pa-g ga-p



 10%|█         | 10/100 [00:08<00:56,  1.60it/s]

Aj A-Aj A-j



 11%|█         | 11/100 [00:08<00:55,  1.60it/s]

akAD aka-aD aka-A



 19%|█▉        | 19/100 [00:13<00:47,  1.70it/s]

Ada A-Ada A-da



 27%|██▋       | 27/100 [00:18<00:52,  1.40it/s]

atAr ata-ar ata-Ar



 28%|██▊       | 28/100 [00:19<00:50,  1.43it/s]

Uv U-v a-v



 30%|███       | 30/100 [00:20<00:56,  1.23it/s]

ahoda ahA-uda aha-uda



 33%|███▎      | 33/100 [00:22<00:42,  1.56it/s]

atv a-tv at-v



 42%|████▏     | 42/100 [00:27<00:33,  1.71it/s]

DAra DA-ra Da-arT



 46%|████▌     | 46/100 [00:29<00:32,  1.64it/s]

radaz rat-az ra-dz



 53%|█████▎    | 53/100 [00:34<00:30,  1.54it/s]

agv ak-v a-v



 65%|██████▌   | 65/100 [00:42<00:24,  1.40it/s]

poz pa-uz pa-ud



 66%|██████▌   | 66/100 [00:43<00:22,  1.50it/s]

fm f-m i-m



 70%|███████   | 70/100 [00:45<00:19,  1.52it/s]

arAN ara-aN ara-Ag



 73%|███████▎  | 73/100 [00:48<00:19,  1.40it/s]

anoma anaH-ma ana-uma



 74%|███████▍  | 74/100 [00:49<00:21,  1.24it/s]

ehAr eha-ar Aha-ar



 89%|████████▉ | 89/100 [00:58<00:06,  1.81it/s]

zob zaH-b za-uk



 97%|█████████▋| 97/100 [01:04<00:02,  1.42it/s]

jova jaH-va ja-uja



100%|██████████| 100/100 [01:06<00:00,  1.50it/s]

Passed: 82/100, 82.0





In [38]:
dtest[1]

''