In [1]:
import numpy as np
import tensorflow as tf

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  return f(*args, **kwds)


In [2]:
from tensorflow.keras.layers import Dropout, concatenate, LSTM, Dense, concatenate, Embedding
embed_len = 32
vocab_size = 223788
class QUA_Net(tf.keras.Model):
    def __init__(self):
        super(QUA_Net, self).__init__()
        self.q_embed = Embedding(vocab_size, embed_len)
        self.s_embed = Embedding(vocab_size, embed_len)
        self.q_lstm = LSTM(128)
        self.s_lstm = LSTM(128)
        self.dropout = Dropout(0.33) #0.3
        self.dense = Dense(64, activation='relu')
        self.dense2 = Dense(1, activation = 'sigmoid')
        
    def call(self, inputs):
        q = self.q_embed(inputs[0])
        s = self.s_embed(inputs[1])
        
        q_out = self.q_lstm(q)

        s_out = self.s_lstm(s)
        
        merge = concatenate([q_out, s_out], axis = -1)
        drop = self.dropout(merge)
        drop2 = self.dense(drop)
        out = self.dense2(drop2)
        return out

model = QUA_Net()

In [20]:
import ujson as json
import json

with open('word2idx.json') as f:
    word2idx = json.load(f)

idx2word = {int(v): k for k, v in word2idx.items()}

with open('idx2word.json', 'w') as outfile:
    json.dump(idx2word, outfile)
    

In [35]:
import ujson as json
import json
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.preprocessing.text import one_hot
from nltk import tokenize

with open('idx2word.json') as f:
    idx2word = json.load(f)



def translate_idxs(words):
    translated = ""
    for line in words:
        for idx in line:
            if idx == 1 or idx == 0:
                continue
            #print(idx2word[idx])
            translated  += idx2word[str(idx)] + " "
    return translated
        

'''
Input:
cw_idxs: indeces for the context paragraph. you can translate this to a sentence with indx2word
qw_idxs: question to answer

Output:
Reduced context based on QuaNet output

'''
def transform_context(cw_idxs, qw_idxs):
    translated_context = translate_idxs(cw_idxs)
    context_sentences = tokenize.sent_tokenize(translated_context)
    #translated_q = [translate_idxs(qw_idxs)] * len(context_sentences)
    translated_q = ["What is a prime number"] * len(context_sentences)

    one_hot_sentences = [one_hot(s, vocab_size) for s in context_sentences]
    padded_sentences = sequence.pad_sequences(one_hot_sentences, maxlen=323)
    
    one_hot_qs = [one_hot(q, vocab_size) for q in translated_q]
    padded_qs = sequence.pad_sequences(one_hot_qs, maxlen=323)
    
    print(translated_q[0])
    predictions = model.predict([padded_qs, padded_sentences])
    best_sent = np.argmax(predictions)
    print(context_sentences[best_sent])
    print(predictions)

array = np.load("stephanie_ex.npy")
print(array)
transform_context(array, [[1, 5,3,1, 80, 1307, 7]])

[[   1   80 5137 ...    0    0    0]
 [   1   80 5137 ...    0    0    0]
 [   1   80 5137 ...    0    0    0]
 ...
 [   1   80  449 ...    1 1307    3]
 [   1   80  449 ...    1 1307    3]
 [   1   80  449 ...    1 1307    3]]
What is a prime number
The uniqueness in this theorem requires excluding 1 as a prime because one can include arbitrarily many instances of 1 in any , e.g.
[[0.4971872 ]
 [0.49897256]
 [0.49899364]
 [0.4987684 ]
 [0.49996978]
 [0.49797595]
 [0.49810866]
 [0.4971872 ]
 [0.49897256]
 [0.49899364]
 [0.4987684 ]
 [0.49996978]
 [0.49797595]
 [0.49810866]
 [0.4971872 ]
 [0.49897256]
 [0.49899364]
 [0.4987684 ]
 [0.49996978]
 [0.49797595]
 [0.49810866]
 [0.4971872 ]
 [0.49897256]
 [0.49899364]
 [0.4987684 ]
 [0.49996978]
 [0.49797595]
 [0.49810866]
 [0.4971872 ]
 [0.49897256]
 [0.49899364]
 [0.4987684 ]
 [0.49996978]
 [0.49797595]
 [0.49810866]
 [0.4971872 ]
 [0.49897256]
 [0.49899364]
 [0.4987684 ]
 [0.49996978]
 [0.49797595]
 [0.49810866]
 [0.4971872 ]
 [0.49897256]
