# Imports

In [1]:
!pip -q install datasets

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/486.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m31.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m28.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [22]:
from datasets import load_dataset
import numpy as np
from model import MemNN

# Load the "facebook/babi_qa" dataset, configuration "en-10k-qa1".
dataset = load_dataset("facebook/babi_qa", "en-10k-qa1")



  0%|          | 0/2 [00:00<?, ?it/s]

# Preprocess data

In [3]:
def get_dataset(dataset, max_story_length=10):
    """Bucket the training split into (story, question, answer) samples by story length.

    Each record's ``"text"`` sequence is walked in order. Lines without a
    ``"?"`` are accumulated as story sentences. Every line containing ``"?"``
    is treated as a question about the sentences accumulated so far and is
    paired with the entry of ``"answer"`` at the same position.

    Args:
        dataset: Mapping where ``dataset["train"]["story"]`` is a sequence of
            records, each holding parallel ``"text"`` and ``"answer"``
            sequences.
        max_story_length: Largest story length (in sentences) that is kept;
            also the number of buckets returned.

    Returns:
        A tuple ``(stories, questions, answers)`` of lists, each with
        ``max_story_length`` numpy arrays. Index ``k - 1`` holds the samples
        whose story has exactly ``k`` sentences: ``stories[k - 1]`` has shape
        ``(n, k)``, ``questions[k - 1]`` has shape ``(n, 1)`` and
        ``answers[k - 1]`` has shape ``(n,)``. Buckets without samples are
        arrays of length 0.
    """
    stories = [[] for _ in range(max_story_length)]
    questions = [[] for _ in range(max_story_length)]
    answers = [[] for _ in range(max_story_length)]

    # Fetch the column once instead of re-indexing it for every record.
    # NOTE(review): with Hugging Face datasets, column access may rebuild the
    # whole column on each lookup -- confirm for the installed version.
    records = dataset["train"]["story"]

    for record in records:
        sentences = record["text"]
        record_answers = record["answer"]
        current_story = []  # Sentences seen so far in this record

        for position, sentence in enumerate(sentences):
            if "?" not in sentence:
                current_story.append(sentence)
                continue

            story_length = len(current_story)
            # The lower bound matters: a question asked before any sentence
            # has length 0, and index -1 would silently put an empty story
            # into the longest-story bucket.
            if 1 <= story_length <= max_story_length:
                bucket = story_length - 1
                # Copy the story: current_story keeps growing afterwards.
                stories[bucket].append(list(current_story))
                questions[bucket].append(sentence)
                answers[bucket].append(record_answers[position])

    # Convert the per-bucket lists to numpy arrays
    stories = [np.array(story) for story in stories]
    questions = [np.array(question).reshape(-1, 1) for question in questions]
    answers = [np.array(answer) for answer in answers]

    return stories, questions, answers


In [4]:
# Bucket the training split by story length (default: up to 10 sentences).
stories, questions, answers = get_dataset(dataset)

In [5]:
# Sanity check: bucket index 1 holds the two-sentence stories.
questions[1].shape, stories[1].shape

((2000, 1), (2000, 2))

In [6]:
# Number of sentence slots every story is padded to before merging.
max_story_len = 10

def pad_stories(stories, max_story_len=10):
    """Right-pad every story with empty strings up to ``max_story_len`` sentences.

    Args:
        stories: 2-D array of shape ``(batch_size, n_sentence)``.
        max_story_len: Minimum number of sentence columns in the result.

    Returns:
        An array of shape ``(batch_size, max(n_sentence, max_story_len))``.
        The input object itself is returned when it is already wide enough.
    """
    batch_size, n_sentence = stories.shape

    missing = max_story_len - n_sentence
    if missing > 0:
        # np.full keeps the padding 2-D even when batch_size is 0; a nested
        # list literal collapses to shape (0,) there and makes concatenate fail.
        padding = np.full((batch_size, missing), "")
        stories = np.concatenate((stories, padding), axis=1)

    return stories

# Merge every second bucket (indices 1, 3, 5, ... i.e. stories with an even
# number of sentences) into one dataset. Concatenating along the sample axis
# gives the same result as stacking + reshaping when all buckets have the same
# size, and also works when they do not.
merged_bucket_ids = range(1, len(stories), 2)
merged_stories = np.concatenate(
    [pad_stories(stories[i], max_story_len) for i in merged_bucket_ids], axis=0
)
merged_questions = np.concatenate([questions[i] for i in merged_bucket_ids], axis=0)
merged_answers = np.concatenate(
    [answers[i] for i in merged_bucket_ids], axis=0
).reshape(-1, 1)

In [7]:
# Spot-check one merged sample to confirm the padding and the pairing of
# story, question and answer.
indx = 16

merged_stories[indx], merged_questions[indx], merged_answers[indx]

(array(['Mary journeyed to the garden.', 'Sandra went back to the office.',
        '', '', '', '', '', '', '', ''], dtype='<U33'),
 array(['Where is Sandra?'], dtype='<U16'),
 array(['office'], dtype='<U8'))

In [8]:
# The three merged arrays should have the same number of rows.
merged_stories.shape, merged_questions.shape, merged_answers.shape

((10000, 10), (10000, 1), (10000, 1))

In [23]:
# Shape of the last bucket (the longest stories that were kept).
stories[-1].shape

(2000, 10)

In [10]:
# Collect the distinct words of the last bucket's stories and questions.
vocab = set()
for s in stories[-1]:
    for w in " ".join(s).split():
        if "." in w:
            w = w[:-1]  # drop the trailing full stop
        vocab.add(w)
for q in questions[-1]:
    # Each question row is a length-1 array holding the question string.
    for w in q[0].split():
        if "?" in w:
            w = w[:-1]  # drop the trailing question mark
        vocab.add(w)

# Sort before converting: set iteration order can differ between runs, which
# would give the words a different order on every run.
vocab = np.array(sorted(vocab))


def get_vocab(stories, questions, answers):
    """Placeholder: not implemented; currently does nothing and returns None.

    TODO: presumably meant to wrap the vocabulary-building code of this
    cell -- confirm the intent before relying on it.
    """
    pass

In [11]:
# Number of distinct words collected.
len(vocab)

19

# MemNN

In [24]:
# Model used for the two-sentence-story experiment below.
mem_1 = MemNN(
    vocab,
    max_query_len=1,
    max_story_len=10,
    vocab_size=1000,
    embedding_size=128,
    k=5,
)
mem_1.batch_size = 32

In [25]:
# NOTE(review): presumably builds/compiles the underlying model -- confirm in model.MemNN.
mem_1.compile()

In [26]:
# Print the layer-by-layer summary of the wrapped model.
mem_1.model.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 story (InputLayer)             [(32, 10, 128)]      0           []                               
                                                                                                  
 query (InputLayer)             [(32, 1, 128)]       0           []                               
                                                                                                  
 lstm_4 (LSTM)                  (32, 10, 128)        131584      ['story[0][0]']                  
                                                                                                  
 tf.math.l2_normalize_13 (TFOpL  (32, 1, 128)        0           ['query[0][0]']                  
 ambda)                                                                                     

In [27]:
# Train mem_1 on the bucket of two-sentence stories only.
# Note: this rebinds max_story_len, which an earlier cell used as the padding width.
max_story_len = 2

mem_1.fit(
    stories[max_story_len - 1],
    questions[max_story_len - 1].reshape(-1, 1),
    answers[max_story_len - 1],
    epochs=50,
    validation_split=0.2,
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


In [28]:
# Fresh model trained on the bucket of four-sentence stories only.
max_story_len = 4

mem_2 = MemNN(
    vocab,
    max_query_len=1,
    max_story_len=10,
    vocab_size=1000,
    embedding_size=128,
    k=5,
)
mem_2.batch_size = 32

mem_2.compile()

mem_2.fit(
    stories[max_story_len - 1],
    questions[max_story_len - 1].reshape(-1, 1),
    answers[max_story_len - 1],
    epochs=50,
    validation_split=0.2,
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


In [29]:
# Fresh model trained on the bucket of six-sentence stories only.
max_story_len = 6

mem_3 = MemNN(
    vocab,
    max_query_len=1,
    max_story_len=10,
    vocab_size=1000,
    embedding_size=128,
    k=5,
)
mem_3.batch_size = 32

mem_3.compile()

mem_3.fit(
    stories[max_story_len - 1],
    questions[max_story_len - 1].reshape(-1, 1),
    answers[max_story_len - 1],
    epochs=50,
    validation_split=0.2,
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


In [None]:
# merged dataset
# Fit mem_4 on the six-sentence bucket first, then fit the same object on the
# merged, padded dataset.
max_story_len = 6

mem_4 = MemNN(
    vocab,
    max_query_len=1,
    max_story_len=10,
    vocab_size=1000,
    embedding_size=128,
    k=5,
)
mem_4.batch_size = 32

mem_4.compile()

mem_4.fit(
    stories[max_story_len - 1],
    questions[max_story_len - 1].reshape(-1, 1),
    answers[max_story_len - 1],
    epochs=50,
    validation_split=0.2,
)

mem_4.fit(
    merged_stories,
    merged_questions,
    merged_answers,
    epochs=150,
    validation_split=0.2,
    shuffle=True,
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
Epoch 1/150
Epoch 2/150