In [72]:
import tensorflow as tf
import keras as K
import numpy as np 
from sklearn.model_selection import train_test_split
from keras.preprocessing.sequence import pad_sequences
from keras.datasets import imdb
import io
from os import path
from keras.layers import Bidirectional, LSTM

PAD = "<PAD>"
START = "<START>"
UNK = "<UNK>"
MAX_SIZE = 50
HIDDEN_SIZE = 256

In [88]:
def build_vocab(data):
    word_to_id = dict()
    word_to_id['<PAD>'] = 0
    word_to_id['<START>'] = 1
    word_to_id['<UNK>'] = 2
    index = 3
    if isinstance(data, io.TextIOWrapper):
        for line in data:
            line = line.strip()
            for i in range(len(line)):
                if line[i] not in word_to_id:
                    word_to_id[line[i]] = index
                    index += 1
                if i < len(line) - 2: 
                    if line[i:i+2] not in word_to_id:
                        word_to_id[line[i:i+2]] = index
                        index += 1
    elif isinstance(data, str):
        for i in range(len(data)):
            if data[i] not in word_to_id:
                word_to_id[data[i]] = index
                index += 1
            if i < len(data) - 2:
                if data[i:i+2] not in word_to_id:
                    word_to_id[data[i:i+2]] = index
                    index += 1
        
    id_to_word = {v:k for k,v in word_to_id.items()}
    
    return word_to_id, id_to_word

with open("/Users/petercbu/Google Drive/University/2019 Spring/NLP/hw1_nlp_sapienza_2019/resources/train/input/as.txt") as f:
    word_to_id_as, id_to_word_as = build_vocab(f)


701079


In [91]:
VOCAB_SIZE = len(word_to_id_as)

701079


In [92]:
def create_input_dataset(file, word_to_id):
    x = []
    for line in file:
        feature_vector = []
        feature_vector.append(word_to_id[START])
        
        # Build feature vector
        for i in range(len(line)):
            unigram = line[i]
            if unigram in word_to_id:
                feature_vector.append(word_to_id[unigram])
            else:
                feature_vector.append(word_to_id[UNK])
            
            if i < len(line) - 2:
                bigram = line[i:i+2]
                if bigram in word_to_id:
                    feature_vector.append(word_to_id[bigram])
                else:
                    feature_vector.append(word_to_id[UNK])
        
        x.append(np.array(feature_vector))
    return np.array(x)

In [93]:
def BIES_to_numerical(file_path):
    BIES_to_number = {'B': 0, 'I': 1, 'E': 2, 'S': 3}
    y = []
    with open(file_path, 'r', encoding='utf-8') as f:
        
        for line in f:
            line = line.strip()
            new_line = []
            for ch in line:
                new_line.append(str(BIES_to_number[ch]))
            y.append(new_line)
    return np.array(y)

In [94]:
with open("/Users/petercbu/Google Drive/University/2019 Spring/NLP/hw1_nlp_sapienza_2019/resources/train/input/as.txt") as f:    
    train_x_as = create_input_dataset(f, word_to_id_as)

with open("/Users/petercbu/Google Drive/University/2019 Spring/NLP/hw1_nlp_sapienza_2019/resources/dev/input/as.txt") as f:
    dev_x_as = create_input_dataset(f, word_to_id_as)
    
train_y_as = BIES_to_numerical("/Users/petercbu/Google Drive/University/2019 Spring/NLP/hw1_nlp_sapienza_2019/resources/train/labels/as.txt")
dev_y_as = BIES_to_numerical("/Users/petercbu/Google Drive/University/2019 Spring/NLP/hw1_nlp_sapienza_2019/resources/dev/labels/as.txt")

train_x_as = pad_sequences(train_x_as, truncating='pre', padding='post', maxlen=MAX_SIZE)
dev_x_as = pad_sequences(dev_x_as, truncating='pre', padding='post', maxlen=MAX_SIZE)
train_y_as = pad_sequences(train_y_as, truncating='pre', padding='post', maxlen=MAX_SIZE)
dev_y_as = pad_sequences(dev_y_as, truncating='pre', padding='post', maxlen=MAX_SIZE)

#print(train_x_as[0:2])
#print(train_y_as[0:2])
#print(dev_x_as[0:2])
#print(dev_y_as[0:2])


In [95]:
train_y_as = K.utils.to_categorical(train_y_as, 4, dtype='int')
dev_y_as = K.utils.to_categorical(dev_y_as, 4, dtype='int')

print(train_y_as.shape)
print(dev_y_as.shape)

(708953, 50, 4)
(14432, 50, 4)


In [96]:
def create_keras_model(vocab_size, embedding_size, hidden_size, dropout, recurrent_dropout):
    print("Creating KERAS model")

    # define LSTM
    model = K.models.Sequential()
    model.add(K.layers.Embedding(vocab_size, embedding_size, mask_zero=True))
    model.add(Bidirectional(LSTM(hidden_size, dropout=dropout, recurrent_dropout=recurrent_dropout, return_sequences=True)))
    model.add(K.layers.Dense(4, activation='softmax'))
    
    # we are going to use the Adam optimizer which is a really powerful optimizer.
    optimizer = K.optimizers.Adam()
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['acc'])

    return model

In [None]:
batch_size = 32
epochs = 3
EMBEDDING_SIZE = 150
model = create_keras_model(VOCAB_SIZE, EMBEDDING_SIZE, HIDDEN_SIZE, 0.2, 0.2)
# Let's print a summary of the model
model.summary()

cbk = K.callbacks.TensorBoard("logging/keras_model")
print("\nStarting training...")
model.fit(train_x_as, train_y_as, epochs=epochs, batch_size=batch_size,
          shuffle=True, validation_data=(dev_x_as, dev_y_as), callbacks=[cbk]) 
print("Training complete.\n")

#print("\nEvaluating test...")
#loss_acc = model.evaluate(test_x, test_y, verbose=0)
#print("Test data: loss = %0.6f  accuracy = %0.2f%% " % (loss_acc[0], loss_acc[1]*100))

Creating KERAS model


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_4 (Embedding)      (None, None, 150)         105161850 
_________________________________________________________________
bidirectional_4 (Bidirection (None, None, 512)         833536    
_________________________________________________________________
dense_4 (Dense)              (None, None, 4)           2052      
Total params: 105,997,438
Trainable params: 105,997,438
Non-trainable params: 0
_________________________________________________________________

Starting training...


Train on 708953 samples, validate on 14432 samples


Epoch 1/3


    32/708953 [..............................] - ETA: 47:51:05 - loss: 0.5617 - acc: 0.7500

    64/708953 [..............................] - ETA: 30:50:32 - loss: 0.5590 - acc: 0.7500

    96/708953 [..............................] - ETA: 25:29:01 - loss: 0.5560 - acc: 0.7500

   128/708953 [..............................] - ETA: 23:19:00 - loss: 0.5522 - acc: 0.7500

   160/708953 [..............................] - ETA: 21:42:17 - loss: 0.5473 - acc: 0.7500

   192/708953 [..............................] - ETA: 20:22:02 - loss: 0.5408 - acc: 0.7500

   224/708953 [..............................] - ETA: 19:42:57 - loss: 0.5318 - acc: 0.7500

   256/708953 [..............................] - ETA: 19:11:55 - loss: 0.5197 - acc: 0.7572

   288/708953 [..............................] - ETA: 18:52:11 - loss: 0.5041 - acc: 0.7660

   320/708953 [..............................] - ETA: 18:41:47 - loss: 0.4887 - acc: 0.7740

   352/708953 [..............................] - ETA: 18:23:22 - loss: 0.4812 - acc: 0.7803

   384/708953 [..............................] - ETA: 18:10:43 - loss: 0.4709 - acc: 0.7864

   416/708953 [..............................] - ETA: 17:57:14 - loss: 0.4614 - acc: 0.7918

   448/708953 [..............................] - ETA: 18:03:28 - loss: 0.4638 - acc: 0.7946

   480/708953 [..............................] - ETA: 17:52:42 - loss: 0.4568 - acc: 0.7984

   512/708953 [..............................] - ETA: 17:54:56 - loss: 0.4503 - acc: 0.8016

   544/708953 [..............................] - ETA: 18:00:47 - loss: 0.4448 - acc: 0.8048

   576/708953 [..............................] - ETA: 18:00:54 - loss: 0.4402 - acc: 0.8079

   608/708953 [..............................] - ETA: 17:52:00 - loss: 0.4368 - acc: 0.8101

   640/708953 [..............................] - ETA: 17:40:40 - loss: 0.4328 - acc: 0.8124

   672/708953 [..............................] - ETA: 17:35:40 - loss: 0.4329 - acc: 0.8132

   704/708953 [..............................] - ETA: 17:27:58 - loss: 0.4290 - acc: 0.8151

   736/708953 [..............................] - ETA: 17:24:23 - loss: 0.4248 - acc: 0.8168

   768/708953 [..............................] - ETA: 17:18:46 - loss: 0.4228 - acc: 0.8181

   800/708953 [..............................] - ETA: 17:14:11 - loss: 0.4195 - acc: 0.8195

   832/708953 [..............................] - ETA: 17:14:51 - loss: 0.4161 - acc: 0.8211

   864/708953 [..............................] - ETA: 17:09:54 - loss: 0.4135 - acc: 0.8224

   896/708953 [..............................] - ETA: 17:06:46 - loss: 0.4111 - acc: 0.8235

   928/708953 [..............................] - ETA: 17:03:34 - loss: 0.4075 - acc: 0.8249

   960/708953 [..............................] - ETA: 16:58:06 - loss: 0.4055 - acc: 0.8257

   992/708953 [..............................] - ETA: 16:50:27 - loss: 0.4028 - acc: 0.8268

  1024/708953 [..............................] - ETA: 16:45:55 - loss: 0.4001 - acc: 0.8280

  1056/708953 [..............................] - ETA: 16:45:37 - loss: 0.3974 - acc: 0.8290

  1088/708953 [..............................] - ETA: 16:55:19 - loss: 0.3959 - acc: 0.8296

  1120/708953 [..............................] - ETA: 16:58:40 - loss: 0.3931 - acc: 0.8303

  1152/708953 [..............................] - ETA: 16:55:15 - loss: 0.3921 - acc: 0.8305

  1184/708953 [..............................] - ETA: 16:49:24 - loss: 0.3896 - acc: 0.8312

  1216/708953 [..............................] - ETA: 16:44:58 - loss: 0.3875 - acc: 0.8318

  1248/708953 [..............................] - ETA: 17:18:40 - loss: 0.3858 - acc: 0.8323

  1280/708953 [..............................] - ETA: 17:26:12 - loss: 0.3837 - acc: 0.8330

  1312/708953 [..............................] - ETA: 17:24:04 - loss: 0.3823 - acc: 0.8333

  1344/708953 [..............................] - ETA: 17:23:39 - loss: 0.3805 - acc: 0.8339

  1376/708953 [..............................] - ETA: 17:29:02 - loss: 0.3787 - acc: 0.8345

  1408/708953 [..............................] - ETA: 17:43:45 - loss: 0.3769 - acc: 0.8350

  1440/708953 [..............................] - ETA: 17:43:32 - loss: 0.3754 - acc: 0.8353

  1472/708953 [..............................] - ETA: 17:39:51 - loss: 0.3739 - acc: 0.8359

  1504/708953 [..............................] - ETA: 17:34:27 - loss: 0.3726 - acc: 0.8362

  1536/708953 [..............................] - ETA: 17:29:09 - loss: 0.3718 - acc: 0.8364

  1568/708953 [..............................] - ETA: 17:36:00 - loss: 0.3702 - acc: 0.8368

  1600/708953 [..............................] - ETA: 17:40:28 - loss: 0.3689 - acc: 0.8373

  1632/708953 [..............................] - ETA: 17:55:55 - loss: 0.3676 - acc: 0.8378

  1664/708953 [..............................] - ETA: 18:08:33 - loss: 0.3668 - acc: 0.8381

  1696/708953 [..............................] - ETA: 18:13:48 - loss: 0.3655 - acc: 0.8386

  1728/708953 [..............................] - ETA: 18:28:13 - loss: 0.3656 - acc: 0.8386

  1760/708953 [..............................] - ETA: 18:41:36 - loss: 0.3643 - acc: 0.8390

  1792/708953 [..............................] - ETA: 19:00:23 - loss: 0.3632 - acc: 0.8393

  1824/708953 [..............................] - ETA: 19:12:29 - loss: 0.3622 - acc: 0.8396

  1856/708953 [..............................] - ETA: 19:20:59 - loss: 0.3611 - acc: 0.8400

  1888/708953 [..............................] - ETA: 19:26:22 - loss: 0.3603 - acc: 0.8404

  1920/708953 [..............................] - ETA: 19:29:52 - loss: 0.3591 - acc: 0.8407

  1952/708953 [..............................] - ETA: 19:38:12 - loss: 0.3582 - acc: 0.8410

  1984/708953 [..............................] - ETA: 19:55:40 - loss: 0.3573 - acc: 0.8413

  2016/708953 [..............................] - ETA: 20:03:15 - loss: 0.3564 - acc: 0.8416

  2048/708953 [..............................] - ETA: 20:02:11 - loss: 0.3560 - acc: 0.8418

  2080/708953 [..............................] - ETA: 20:01:03 - loss: 0.3550 - acc: 0.8421

  2112/708953 [..............................] - ETA: 20:04:28 - loss: 0.3541 - acc: 0.8425

  2144/708953 [..............................] - ETA: 20:04:23 - loss: 0.3533 - acc: 0.8427

  2176/708953 [..............................] - ETA: 20:00:57 - loss: 0.3530 - acc: 0.8429

  2208/708953 [..............................] - ETA: 19:58:32 - loss: 0.3524 - acc: 0.8431

  2240/708953 [..............................] - ETA: 19:54:07 - loss: 0.3518 - acc: 0.8434

  2272/708953 [..............................] - ETA: 19:49:59 - loss: 0.3510 - acc: 0.8437

  2304/708953 [..............................] - ETA: 19:52:23 - loss: 0.3503 - acc: 0.8440

  2336/708953 [..............................] - ETA: 19:53:18 - loss: 0.3496 - acc: 0.8442

  2368/708953 [..............................] - ETA: 19:54:42 - loss: 0.3491 - acc: 0.8442

  2400/708953 [..............................] - ETA: 19:55:08 - loss: 0.3484 - acc: 0.8444

  2432/708953 [..............................] - ETA: 19:53:29 - loss: 0.3478 - acc: 0.8445

  2464/708953 [..............................] - ETA: 19:52:24 - loss: 0.3473 - acc: 0.8447

  2496/708953 [..............................] - ETA: 19:48:40 - loss: 0.3467 - acc: 0.8450

  2528/708953 [..............................] - ETA: 19:48:24 - loss: 0.3460 - acc: 0.8452

  2560/708953 [..............................] - ETA: 19:51:25 - loss: 0.3455 - acc: 0.8454

  2592/708953 [..............................] - ETA: 19:47:18 - loss: 0.3448 - acc: 0.8457

  2624/708953 [..............................] - ETA: 19:43:38 - loss: 0.3444 - acc: 0.8458

  2656/708953 [..............................] - ETA: 19:39:24 - loss: 0.3438 - acc: 0.8460

  2688/708953 [..............................] - ETA: 19:35:03 - loss: 0.3433 - acc: 0.8461

  2720/708953 [..............................] - ETA: 19:30:28 - loss: 0.3429 - acc: 0.8463

  2752/708953 [..............................] - ETA: 19:25:59 - loss: 0.3423 - acc: 0.8465

  2784/708953 [..............................] - ETA: 19:22:22 - loss: 0.3420 - acc: 0.8466

  2816/708953 [..............................] - ETA: 19:18:32 - loss: 0.3415 - acc: 0.8468

  2848/708953 [..............................] - ETA: 19:14:26 - loss: 0.3410 - acc: 0.8470

  2880/708953 [..............................] - ETA: 19:10:47 - loss: 0.3406 - acc: 0.8472

  2912/708953 [..............................] - ETA: 19:08:09 - loss: 0.3399 - acc: 0.8474

  2944/708953 [..............................] - ETA: 19:04:08 - loss: 0.3397 - acc: 0.8474

  2976/708953 [..............................] - ETA: 18:59:17 - loss: 0.3393 - acc: 0.8475

  3008/708953 [..............................] - ETA: 18:55:18 - loss: 0.3388 - acc: 0.8477

  3040/708953 [..............................] - ETA: 18:52:04 - loss: 0.3383 - acc: 0.8478

  3072/708953 [..............................] - ETA: 18:48:43 - loss: 0.3380 - acc: 0.8479

  3104/708953 [..............................] - ETA: 18:45:21 - loss: 0.3376 - acc: 0.8480

  3136/708953 [..............................] - ETA: 18:49:33 - loss: 0.3372 - acc: 0.8482

  3168/708953 [..............................] - ETA: 18:51:10 - loss: 0.3368 - acc: 0.8484

  3200/708953 [..............................] - ETA: 18:50:07 - loss: 0.3368 - acc: 0.8483

  3232/708953 [..............................] - ETA: 18:49:08 - loss: 0.3366 - acc: 0.8484

  3264/708953 [..............................] - ETA: 18:47:53 - loss: 0.3363 - acc: 0.8485

  3296/708953 [..............................] - ETA: 18:47:47 - loss: 0.3358 - acc: 0.8487