In [4]:
import nltk
import numpy as np
import pickle as pkl
import pandas as pd
import tensorflow as tf
import keras
from keras import backend as K
from keras.models import Model
from keras.layers import Input, Dense, Lambda
from keras.utils import to_categorical, Sequence
from gensim.models import Doc2Vec
import os
import time
from nltk.tokenize import word_tokenize, sent_tokenize
import tensorflow_hub as hub
from tensorflow.errors import InvalidArgumentError

# TF-Hub handle of the large Universal Sentence Encoder, version 3.
module_url = 'https://tfhub.dev/google/universal-sentence-encoder-large/3'
# Built once and shared by UniversalEmbedding; TF-Hub downloads and caches it
# on first use (see the INFO log below).
embed = hub.Module(spec=module_url)

INFO:tensorflow:Using /tmp/tfhub_modules to cache modules.
INFO:tensorflow:Downloading TF-Hub Module 'https://tfhub.dev/google/universal-sentence-encoder-large/3'.
INFO:tensorflow:Downloaded https://tfhub.dev/google/universal-sentence-encoder-large/3, Total size: 810.60MB
INFO:tensorflow:Downloaded TF-Hub Module 'https://tfhub.dev/google/universal-sentence-encoder-large/3'.


In [2]:
# Move from the notebook's folder up to the project root, which holds
# dbpedia_csv/, Training/, Testing/ and Doc2Vec/ (all opened via relative
# paths below). Guarded so that re-running this cell is a no-op instead of
# climbing one more directory each time.
if not os.path.isdir('dbpedia_csv'):
    os.chdir('..')

In [3]:
# Map 1-based class ids to class names: line k of classes.txt names class k.
with open('dbpedia_csv/classes.txt') as classes_file:
    class_mapping = {
        class_id: class_name.strip()
        for class_id, class_name in enumerate(classes_file, start=1)
    }

print(class_mapping)

{1: 'Company', 2: 'EducationalInstitution', 3: 'Artist', 4: 'Athlete', 5: 'OfficeHolder', 6: 'MeanOfTransportation', 7: 'Building', 8: 'NaturalPlace', 9: 'Village', 10: 'Animal', 11: 'Plant', 12: 'Album', 13: 'Film', 14: 'WrittenWork'}


In [4]:
class USE_Doc2Vec_DataGenerator(Sequence):
    """Keras ``Sequence`` yielding ``([X_USE, X_Doc2Vec], y_onehot)`` per pickled file.

    Each ID in ``list_IDs`` is a file name under ``Training/``. Every file is
    unpickled as ``Xarray``:

    * ``Xarray[0]`` is one class id (1-based, like ``classes.txt``) shared by
      every row of the file.
    * ``Xarray[1]`` is a list of rows; each row is joined with spaces into one
      paragraph string.

    A file is itself a batch, so exactly one file is served per step.

    Args:
        list_IDs: file names (relative to ``Training/``) of the pickled batches.
        batch_size: kept for backward compatibility; the number of rows per
            step is now taken from the file itself.
        n_classes: number of classes for the one-hot targets.
        shuffle: reshuffle the file order at the end of every epoch.
    """

    def __init__(self, list_IDs, batch_size = 32, n_classes=14, shuffle=True):
        self.dim = (1,)  # not used by this class; kept for external readers
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Steps per epoch: one per file.

        ``__getitem__`` serves a whole file per index. The former
        ``floor(len(list_IDs) / batch_size)`` therefore visited only
        1/batch_size of the files each epoch.
        """
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Return inputs and one-hot targets from the ``index``-th (shuffled) file."""
        file_ids = [self.list_IDs[k] for k in self.indexes[index:index + 1]]
        return self.__data_generation(file_ids)

    def on_epoch_end(self):
        """Reset the file order, shuffling it when ``self.shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Load every listed file and build ``([X_USE, X_Doc2Vec], y_onehot)``.

        Arrays are sized from the rows actually loaded. A file with fewer
        rows than ``batch_size`` therefore no longer leaves uninitialised
        ``np.empty`` values in the batch, and a larger file no longer overflows
        it. Rows from every listed file are used, not just the last one.
        """
        paragraphs = []
        targets = []
        for item in list_IDs_temp:
            # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
            with open('Training/' + item, 'rb') as batch_file:
                Xarray = pkl.load(batch_file)
            for row in Xarray[1]:
                paragraphs.append(' '.join(row))
                targets.append(Xarray[0] - 1)  # 0-based target for to_categorical

        X_USE = np.array(paragraphs, dtype=object)[:, np.newaxis]
        X_Doc2Vec = np.empty((len(paragraphs), 300))
        for i, cur_sent in enumerate(paragraphs):
            # Doc2Vec only sees the paragraph's last sentence.
            # NOTE(review): gensim documents infer_vector(doc_words) as a list of
            # tokens; a plain str is iterated character by character. Left as-is
            # so training features match the evaluation cell — confirm intent.
            X_Doc2Vec[i, ] = Doc2Vec_model.infer_vector(sent_tokenize(cur_sent)[-1])
        y = np.array(targets, dtype=int)

        return [X_USE, X_Doc2Vec], to_categorical(y, num_classes = self.n_classes)

In [5]:
Doc2Vec_model = Doc2Vec.load('Doc2Vec/Doc2Vec.model')
# Drop training-only buffers. The flags keep the doc-tag vectors and the data
# infer_vector() needs, since both the generator and the evaluation cell call it.
# NOTE(review): this method appears to be gone in gensim 4.x — pin gensim<4 or drop it.
Doc2Vec_model.delete_temporary_training_data(
    keep_doctags_vectors=True,
    keep_inference=True,
)

In [None]:
params = {'batch_size': 32, 'n_classes': 14, 'shuffle': True}
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('Training/Metadata', 'rb') as metadata_file:
    partition, labels = pkl.load(metadata_file)
# Shift class ids down by one (presumably 1-based like classes.txt) to 0-based.
# The dict is reloaded above on every run, so this cell stays re-runnable.
for key in labels.keys():
    labels[key] -= 1
training_generator = USE_Doc2Vec_DataGenerator(partition['train'], **params)

In [6]:
def UniversalEmbedding(x):
    """Embed a batch of strings with the TF-Hub Universal Sentence Encoder.

    Args:
        x: the Keras ``USE_input`` tensor, shape ``(batch, 1)`` of strings.

    Returns:
        The module's ``'default'`` output, declared to Keras as 512 wide.

    ``axis=1`` squeezes only the singleton feature axis. An unqualified
    ``tf.squeeze`` also removes the batch axis whenever a batch holds exactly
    one sample, which hands the module a scalar instead of a 1-D string vector.
    """
    return embed(tf.squeeze(tf.cast(x, tf.string), axis=1),
                 signature='default', as_dict=True)['default']

In [7]:
# Two-branch classifier:
#   * a USE embedding of the paragraph string -> Dense(256)
#   * a 300-d Doc2Vec vector -> Dense(128) -> Dense(128)
# The branches are concatenated and mapped to 14 softmax class scores.
# NOTE(review): unnamed layers get auto-generated names in creation order, and
# the .h5 weights are reloaded in the evaluation cell; presumably keep this
# layer order stable so saved weights still line up — confirm.
USE_input = Input(shape=(1,), dtype = tf.string, name = 'USE_input')
embedding = Lambda(UniversalEmbedding, output_shape = (512, ))(USE_input)
USE_output = Dense(256, activation='relu')(embedding)
Doc2Vec_input = Input(shape = (300,), name = 'Doc2Vec_input')
Doc2Vec_dense = Dense(128, activation = 'relu')(Doc2Vec_input)
Doc2Vec_output = Dense(128, activation = 'relu')(Doc2Vec_dense)
concat = keras.layers.concatenate([USE_output, Doc2Vec_output])
pred = Dense(14, activation = 'softmax')(concat)
model = Model(inputs = [USE_input, Doc2Vec_input], outputs = pred)
model.compile(loss= 'categorical_crossentropy', optimizer = 'adam', metrics = ['accuracy'])

# model.name is used to build the weights file name in later cells.
# NOTE(review): assigning model.name works on older Keras; newer releases may
# make it read-only — confirm against the installed version.
model.name = 'USEparagraph_Doc2Vecsentence'
# NOTE(review): loss_acc_path is not used by any visible cell.
loss_acc_path = model.name+'loss_acc'

INFO:tensorflow:Saver not created because there are no variables in the graph to restore


In [None]:
# Create the output directory *before* training. Otherwise a long run is lost
# when save_weights() fails on a missing directory at the very end.
os.makedirs('./USE_Doc2Vec', exist_ok=True)

with tf.Session() as session:
    K.set_session(session)
    # Initialise graph variables and tables (presumably the TF-Hub module's
    # lookup tables) before Keras runs the graph.
    session.run(tf.global_variables_initializer())
    session.run(tf.tables_initializer())
    history = model.fit_generator(generator = training_generator, epochs = 1, verbose = 1)
    model.save_weights('./USE_Doc2Vec/'+model.name+'.h5')

In [10]:
# Evaluate the trained model on every Testing/<label>/<sent_len>/ bucket.
# Per-bucket accuracy is checkpointed after each bucket so an interrupted run
# resumes where it stopped.
ACCURACY_PATH = './USE_Doc2Vec/USEparagraph_Doc2Vecsentence_accuracy'
MAX_SENT_LEN = 100  # buckets whose sentence count is >= this are skipped


def _load_test_bucket(label, sent_len):
    """Load one ``Testing/<label>/<sent_len>/`` bucket.

    Returns ``(X_USE, X_Doc2Vec, true_labels)``:

    * ``X_USE`` is an ``(n, 1)`` object array of each sample's sentences
      joined with spaces.
    * ``X_Doc2Vec`` is ``(n, 300)`` Doc2Vec vectors of each sample's last
      sentence.
    * ``true_labels`` lists the stored class ids, which are compared against
      ``argmax + 1``, i.e. 1-based.
    """
    bucket_dir = 'Testing/{}/{}'.format(label, sent_len)
    paragraphs, vectors, true_labels = [], [], []
    for file_name in os.listdir(bucket_dir):
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open('{}/{}'.format(bucket_dir, file_name), 'rb') as sample_file:
            sample = pkl.load(sample_file)
        cur_sent = ' '.join(sample[1])
        true_labels.append(sample[0])
        paragraphs.append(cur_sent)
        # NOTE(review): infer_vector receives a str rather than a token list;
        # kept identical to the training generator — confirm this is intended.
        vectors.append(Doc2Vec_model.infer_vector(sent_tokenize(cur_sent)[-1]))
    X_USE = np.array(paragraphs, dtype=object)[:, np.newaxis]
    # Fill once instead of np.vstack per sample, which re-copies the array each time.
    X_Doc2Vec = np.empty((len(vectors), 300))
    for i, vector in enumerate(vectors):
        X_Doc2Vec[i] = vector
    return X_USE, X_Doc2Vec, true_labels


with tf.Session() as session:
    K.set_session(session)
    session.run(tf.global_variables_initializer())
    session.run(tf.tables_initializer())
    model.load_weights('./USE_Doc2Vec/'+model.name+'.h5')
    # Resume from an earlier checkpoint when one exists; otherwise start fresh.
    if os.path.exists(ACCURACY_PATH):
        with open(ACCURACY_PATH, 'rb') as accuracy_file:
            accuracies = pkl.load(accuracy_file)
    else:
        accuracies = {}
    start = time.time()
    for label in os.listdir('Testing'):
        label_accuracies = accuracies.setdefault(label, {})
        for sent_len in os.listdir('Testing/{}'.format(label)):
            if sent_len not in label_accuracies and int(sent_len) < MAX_SENT_LEN:
                X_USE, X_Doc2Vec, true_labels = _load_test_bucket(label, sent_len)
                if not true_labels:
                    # model.predict cannot be run on an empty batch.
                    print('Empty bucket {} Sentence Length: {}'.format(label, sent_len))
                else:
                    try:
                        result = np.asarray(model.predict([X_USE,X_Doc2Vec]))
                        result = np.argmax(result, axis = 1)+1  # back to 1-based ids
                        accuracy = np.sum(result == true_labels)/len(true_labels)
                        print('{}, sentence_length:{}, Accuracy:{}'.format(label, sent_len,accuracy))
                        label_accuracies[sent_len] = accuracy
                    except InvalidArgumentError:
                        # Reported and skipped so the remaining buckets still run.
                        print("Error with {} Sentence Length: {}".format(label, sent_len))
                print('elapsed time: {}'.format((time.time()-start)/60))
            with open(ACCURACY_PATH, 'wb') as accuracy_file:
                pkl.dump(accuracies, accuracy_file)

Class9, sentence_length:27, Accuracy:0.9968093592129753
elapsed time: 1.43209175268809
Class9, sentence_length:74, Accuracy:0.9694444444444444
elapsed time: 1.6482942660649618
Class9, sentence_length:73, Accuracy:0.9686684073107049
elapsed time: 1.8775283058484395
Class9, sentence_length:15, Accuracy:0.9909851899549259
elapsed time: 3.18300089041392
Class9, sentence_length:2, Accuracy:0.7456
elapsed time: 4.130509332815806
Class9, sentence_length:16, Accuracy:0.9911465250110668
elapsed time: 5.448804144064585
Class9, sentence_length:10, Accuracy:0.975859987929994
elapsed time: 6.746833662192027
Class9, sentence_length:31, Accuracy:0.9975349219391948
elapsed time: 8.066641259193421
Class9, sentence_length:50, Accuracy:0.9939659901261657
elapsed time: 8.9037136276563
Class9, sentence_length:47, Accuracy:0.9950135992747053
elapsed time: 9.899036494890849
Class9, sentence_length:14, Accuracy:0.9883745508349187
elapsed time: 11.237745928764344
Class9, sentence_length:48, Accuracy:0.99568552

Class10, sentence_length:2, Accuracy:0.2354
elapsed time: 74.56546371777853
Error with Class10 Sentence Length: 16
elapsed time: 75.20958453416824
Class10, sentence_length:10, Accuracy:0.9496664645239539
elapsed time: 76.70873858531316
Class10, sentence_length:31, Accuracy:0.983942908117752
elapsed time: 77.57324966192246
Class10, sentence_length:50, Accuracy:0.9823754789272031
elapsed time: 78.2043839176496
Class10, sentence_length:47, Accuracy:0.9833101529902643
elapsed time: 78.88632001876832
Class10, sentence_length:97, Accuracy:1.0
elapsed time: 78.91228600343068
Class10, sentence_length:14, Accuracy:0.9751595259799453
elapsed time: 80.3127939661344
Class10, sentence_length:48, Accuracy:0.9828937990021382
elapsed time: 81.05427431662878
Class10, sentence_length:79, Accuracy:0.9858757062146892
elapsed time: 81.31410437027613
Class10, sentence_length:90, Accuracy:1.0
elapsed time: 81.41615285873414
Class10, sentence_length:12, Accuracy:0.9663093415007658
elapsed time: 82.67954920132

Class8, sentence_length:74, Accuracy:0.9913941480206541
elapsed time: 137.59672298034033
Class8, sentence_length:73, Accuracy:0.9900908340214699
elapsed time: 138.338949962457
Class8, sentence_length:15, Accuracy:0.9723247232472325
elapsed time: 139.41447155475618
Class8, sentence_length:2, Accuracy:0.5768
elapsed time: 140.44248097340267
Class8, sentence_length:16, Accuracy:0.9745808545159546
elapsed time: 141.48473694721858
Class8, sentence_length:10, Accuracy:0.9436917866215072
elapsed time: 142.74691752592722
Class8, sentence_length:31, Accuracy:0.9907131011608623
elapsed time: 143.87540151675543
Class8, sentence_length:50, Accuracy:0.9924343569203382
elapsed time: 145.05704837242763
Class8, sentence_length:47, Accuracy:0.9915433403805497
elapsed time: 146.17470109462738
Error with Class8 Sentence Length: 97
elapsed time: 146.22307190497716
Class8, sentence_length:14, Accuracy:0.967974180734856
elapsed time: 147.33797008593876
Class8, sentence_length:48, Accuracy:0.9918349806617963

Class8, sentence_length:43, Accuracy:0.992072929052715
elapsed time: 220.74717780749003
Class8, sentence_length:9, Accuracy:0.9318840579710145
elapsed time: 221.97878175973892
Class7, sentence_length:27, Accuracy:0.9645514223194749
elapsed time: 223.90774149099985
Class7, sentence_length:74, Accuracy:0.9704216488357458
elapsed time: 224.9330598394076
Class7, sentence_length:73, Accuracy:0.9672514619883041
elapsed time: 226.04641804297765
Class7, sentence_length:15, Accuracy:0.9457096380642538
elapsed time: 227.58585925896963
Class7, sentence_length:2, Accuracy:0.5274
elapsed time: 228.63774061203003
Class7, sentence_length:16, Accuracy:0.9488277268093782
elapsed time: 230.1685891230901
Class7, sentence_length:10, Accuracy:0.9204020100502512
elapsed time: 231.5204332749049
Class7, sentence_length:31, Accuracy:0.9647138656412576
elapsed time: 233.22005010843276
Class7, sentence_length:50, Accuracy:0.9691286785920369
elapsed time: 235.06961245536803
Class7, sentence_length:47, Accuracy:0.

Class7, sentence_length:19, Accuracy:0.9581097812628973
elapsed time: 338.4479668855667
Class7, sentence_length:70, Accuracy:0.9700598802395209
elapsed time: 339.69754914840064
Class7, sentence_length:36, Accuracy:0.9654676258992806
elapsed time: 341.46846413215
Class7, sentence_length:67, Accuracy:0.967885816235504
elapsed time: 342.8404311656952
Class7, sentence_length:43, Accuracy:0.965526247061896
elapsed time: 344.67842274109523
Class7, sentence_length:9, Accuracy:0.908381676335267
elapsed time: 345.95108224948245
Class3, sentence_length:27, Accuracy:0.9590997095837367
elapsed time: 347.5892359375954
Class3, sentence_length:74, Accuracy:0.977726574500768
elapsed time: 348.4627216418584
Class3, sentence_length:73, Accuracy:0.9767441860465116
elapsed time: 349.330135623614
Class3, sentence_length:15, Accuracy:0.7789779808890736
elapsed time: 350.7413659652074
Class3, sentence_length:2, Accuracy:0.2298
elapsed time: 351.7788255492846
Error with Class3 Sentence Length: 16
elapsed time

Class3, sentence_length:35, Accuracy:0.972001097996157
elapsed time: 442.8526273131371
Class3, sentence_length:32, Accuracy:0.9692227438706312
elapsed time: 444.2811814983686
Class3, sentence_length:46, Accuracy:0.9745538664904164
elapsed time: 445.70223683516184
Class3, sentence_length:71, Accuracy:0.9756258234519104
elapsed time: 446.6204936146736
Class3, sentence_length:85, Accuracy:0.9868421052631579
elapsed time: 446.93394864002863
Class3, sentence_length:19, Accuracy:0.9043062200956937
elapsed time: 448.65625017881393
Class3, sentence_length:70, Accuracy:0.9753164556962025
elapsed time: 449.6200214068095
Class3, sentence_length:36, Accuracy:0.9749512941831339
elapsed time: 451.1314513882001
Class3, sentence_length:67, Accuracy:0.9745434421693414
elapsed time: 452.18960061868034
Class3, sentence_length:43, Accuracy:0.9746875
elapsed time: 453.6783090035121
Class3, sentence_length:9, Accuracy:0.459915611814346
elapsed time: 454.89511450529096
Class6, sentence_length:27, Accuracy:0.

Class6, sentence_length:29, Accuracy:0.9857961053837343
elapsed time: 558.3664821743965
Class6, sentence_length:8, Accuracy:0.9026
elapsed time: 559.6321487665176
Class6, sentence_length:5, Accuracy:0.816
elapsed time: 560.7557052135468
Class6, sentence_length:35, Accuracy:0.9885799404170804
elapsed time: 562.4452477296194
Class6, sentence_length:32, Accuracy:0.9869141089697835
elapsed time: 564.0542524735133
Class6, sentence_length:46, Accuracy:0.9901477832512315
elapsed time: 565.7111143867174
Class6, sentence_length:71, Accuracy:0.9934711643090316
elapsed time: 566.8172794381777
Class6, sentence_length:85, Accuracy:0.99185667752443
elapsed time: 567.234567852815
Class6, sentence_length:19, Accuracy:0.9795031055900622
elapsed time: 568.7545189936956
Class6, sentence_length:70, Accuracy:0.9932185706833594
elapsed time: 569.914286271731
Class6, sentence_length:36, Accuracy:0.9894392758360573
elapsed time: 571.5867652535438
Class6, sentence_length:67, Accuracy:0.9929939280709948
elapsed

Class12, sentence_length:65, Accuracy:0.9953917050691244
elapsed time: 654.4263369679451
Class12, sentence_length:21, Accuracy:0.9897189856065799
elapsed time: 655.7911214868228
Class12, sentence_length:18, Accuracy:0.9837133550488599
elapsed time: 657.16838012139
Class12, sentence_length:98, Accuracy:1.0
elapsed time: 657.2090635418892
Class12, sentence_length:29, Accuracy:0.993108931884442
elapsed time: 658.5278127233188
Class12, sentence_length:8, Accuracy:0.8875550220088035
elapsed time: 659.738830546538
Class12, sentence_length:5, Accuracy:0.7275455091018204
elapsed time: 660.8190828720728
Class12, sentence_length:35, Accuracy:0.995875073659399
elapsed time: 662.1267422318458
Class12, sentence_length:32, Accuracy:0.9944008958566629
elapsed time: 663.4214367548625
Class12, sentence_length:46, Accuracy:0.9952554744525547
elapsed time: 664.6588048179945
Class12, sentence_length:71, Accuracy:0.9943019943019943
elapsed time: 665.4912100990613
Class12, sentence_length:85, Accuracy:0.992

Class5, sentence_length:58, Accuracy:0.9647946353730092
elapsed time: 754.7250420530637
Class5, sentence_length:34, Accuracy:0.9596061600605907
elapsed time: 756.3093405882518
Class5, sentence_length:83, Accuracy:0.9495495495495495
elapsed time: 756.6734996199608
Class5, sentence_length:65, Accuracy:0.9665981500513875
elapsed time: 757.8105977813403
Class5, sentence_length:21, Accuracy:0.9229473684210526
elapsed time: 759.3081868290901
Class5, sentence_length:18, Accuracy:0.898169101008023
elapsed time: 760.7773259560267
Class5, sentence_length:98, Accuracy:0.8421052631578947
elapsed time: 760.7936952789624
Class5, sentence_length:29, Accuracy:0.9518716577540107
elapsed time: 762.3366815050443
Class5, sentence_length:8, Accuracy:0.6515212169735789
elapsed time: 763.5840059161186
Class5, sentence_length:5, Accuracy:0.5341068213642729
elapsed time: 764.720709514618
Class5, sentence_length:35, Accuracy:0.9615483870967741
elapsed time: 766.2921705325444
Class5, sentence_length:32, Accuracy

Class2, sentence_length:23, Accuracy:0.9622518001309186
elapsed time: 855.399389886856
Class2, sentence_length:75, Accuracy:0.9733464955577492
elapsed time: 856.0298964222272
Class2, sentence_length:39, Accuracy:0.9660367022733498
elapsed time: 857.6081128954887
Class2, sentence_length:81, Accuracy:0.974903474903475
elapsed time: 857.9504545251528
Class2, sentence_length:58, Accuracy:0.9674864526886202
elapsed time: 859.2215968171755
Class2, sentence_length:34, Accuracy:0.9669442341660358
elapsed time: 860.789426736037
Class2, sentence_length:83, Accuracy:0.975609756097561
elapsed time: 861.0383578379949
Class2, sentence_length:65, Accuracy:0.9715468184169684
elapsed time: 862.1733839988708
Class2, sentence_length:21, Accuracy:0.9601286173633441
elapsed time: 863.6618417461713
Class2, sentence_length:18, Accuracy:0.9559159908504886
elapsed time: 865.1175014257431
Class2, sentence_length:98, Accuracy:0.9583333333333334
elapsed time: 865.1372157732645
Class2, sentence_length:29, Accuracy

Class14, sentence_length:60, Accuracy:0.987460815047022
elapsed time: 955.1350052396457
Class14, sentence_length:88, Accuracy:0.9908675799086758
elapsed time: 955.2888711849848
Class14, sentence_length:59, Accuracy:0.9875161429186397
elapsed time: 956.5779279152553
Class14, sentence_length:3, Accuracy:0.304
elapsed time: 957.5859675804774
Class14, sentence_length:66, Accuracy:0.9865168539325843
elapsed time: 958.5953816692034
Class14, sentence_length:23, Accuracy:0.9719955898566703
elapsed time: 960.4274215022723
Class14, sentence_length:75, Accuracy:0.9859154929577465
elapsed time: 961.0705161611239
Class14, sentence_length:39, Accuracy:0.9854074889867841
elapsed time: 962.6164145747821
Class14, sentence_length:81, Accuracy:0.9889937106918238
elapsed time: 963.038089911143
Class14, sentence_length:58, Accuracy:0.986644407345576
elapsed time: 964.3254647254944
Class14, sentence_length:34, Accuracy:0.9822018815153827
elapsed time: 965.8292665243149
Class14, sentence_length:83, Accuracy:

Error with Class13 Sentence Length: 64
elapsed time: 1049.4464662631353
Class13, sentence_length:24, Accuracy:0.9630077369439072
elapsed time: 1050.8973367214203
Class13, sentence_length:89, Accuracy:0.950920245398773
elapsed time: 1051.0154514630635
Class13, sentence_length:63, Accuracy:0.9642637091805298
elapsed time: 1051.9352574944496
Class13, sentence_length:49, Accuracy:0.9653767820773931
elapsed time: 1053.1210157076518
Error with Class13 Sentence Length: 60
elapsed time: 1054.121455037594
Class13, sentence_length:88, Accuracy:0.9641025641025641
elapsed time: 1054.2609872261683
Class13, sentence_length:59, Accuracy:0.9639784946236559
elapsed time: 1055.2518022934596
Class13, sentence_length:3, Accuracy:0.3416
elapsed time: 1056.2862033089002
Class13, sentence_length:66, Accuracy:0.9638135003479471
elapsed time: 1057.1157077511152
Class13, sentence_length:23, Accuracy:0.9603889943074004
elapsed time: 1058.4972744385402
Class13, sentence_length:75, Accuracy:0.9667036625971143
elap

Class11, sentence_length:80, Accuracy:0.9953488372093023
elapsed time: 1133.2267489075662
Class11, sentence_length:4, Accuracy:0.3876
elapsed time: 1134.322490398089
Class11, sentence_length:11, Accuracy:0.8838867386433082
elapsed time: 1135.592209049066
Class11, sentence_length:64, Accuracy:0.9917751884852639
elapsed time: 1136.429509182771
Class11, sentence_length:24, Accuracy:0.9585336538461539
elapsed time: 1137.5853606462479
Class11, sentence_length:89, Accuracy:0.9935064935064936
elapsed time: 1137.6968787590663
Class11, sentence_length:63, Accuracy:0.989926124916051
elapsed time: 1138.6913086136183
Class11, sentence_length:49, Accuracy:0.9819819819819819
elapsed time: 1139.655784300963
Class11, sentence_length:60, Accuracy:0.990613266583229
elapsed time: 1140.533094437917
Class11, sentence_length:88, Accuracy:0.9944444444444445
elapsed time: 1140.6530225714048
Class11, sentence_length:59, Accuracy:0.9908200734394125
elapsed time: 1141.5618997454644
Class11, sentence_length:3, Ac

Class4, sentence_length:99, Accuracy:1.0
elapsed time: 1218.5322645227113
Class4, sentence_length:41, Accuracy:0.9954321855235418
elapsed time: 1219.7305363933244
Class4, sentence_length:82, Accuracy:0.9963503649635036
elapsed time: 1220.096635401249
Class4, sentence_length:1, Accuracy:0.6582
elapsed time: 1221.1092551231384
Class4, sentence_length:91, Accuracy:1.0
elapsed time: 1221.199092610677
Class4, sentence_length:80, Accuracy:0.9969924812030075
elapsed time: 1221.6379808624586
Class4, sentence_length:4, Accuracy:0.596958174904943
elapsed time: 1222.8112057288488
Class4, sentence_length:11, Accuracy:0.7074638844301766
elapsed time: 1224.1789991259575
Class4, sentence_length:64, Accuracy:0.9936427209154481
elapsed time: 1225.1716450889905
Class4, sentence_length:24, Accuracy:0.9801641586867305
elapsed time: 1227.057721710205
Class4, sentence_length:89, Accuracy:1.0
elapsed time: 1227.2020302255949
Class4, sentence_length:63, Accuracy:0.9938537185003073
elapsed time: 1228.140089289

Class1, sentence_length:53, Accuracy:0.9420393559928444
elapsed time: 1315.1121275623639
Class1, sentence_length:54, Accuracy:0.9404849375459221
elapsed time: 1316.4704367399215
Class1, sentence_length:20, Accuracy:0.9495674192867694
elapsed time: 1317.959111126264
Class1, sentence_length:77, Accuracy:0.8995502248875562
elapsed time: 1318.3796020587286
Class1, sentence_length:99, Accuracy:0.9473684210526315
elapsed time: 1318.396418873469
Class1, sentence_length:41, Accuracy:0.9487465181058495
elapsed time: 1319.9839559674263
Class1, sentence_length:82, Accuracy:0.8823529411764706
elapsed time: 1320.2101809978485
Class1, sentence_length:1, Accuracy:0.2294
elapsed time: 1320.9373056054114
Class1, sentence_length:91, Accuracy:0.890625
elapsed time: 1320.9848118344942
Class1, sentence_length:80, Accuracy:0.8857758620689655
elapsed time: 1321.28648481369
Class1, sentence_length:4, Accuracy:0.6953390678135627
elapsed time: 1322.3934480230014
Class1, sentence_length:11, Accuracy:0.9093102754