# Memory Representation in Dialogue Systems (Part 2)

This notebook is part 2 of the dynamic memory representation series. See part 1 to get started.

## Process the Text
As in part 1, this notebook performs the same evaluation, this time on bAbI task QA2, Two Supporting Facts. In QA1, there were two types of entities: `person`s and `room`s. QA2 adds a third entity type: `item`s. Each dialogue provides a sequence of statements in which persons move between rooms as before, and also acquire or release items. The key insight is that an item moves between rooms with the person who last acquired it, and stays in the room where it was released. The system therefore has to distinguish rooms from items, and acquiring actions from releasing actions.
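The rule above can be sketched in plain Python before any graph is involved. This is only an illustration under stated assumptions: `locate_items` and the sample `story` are hypothetical and not part of the notebook's pipeline, and the verb sets simply mirror the action lists defined later in this notebook.

```python
# Verb groups (assumed to match the action lists used later in this notebook)
MOVEMENT = {'journeyed', 'moved', 'travelled', 'went'}
ACQUIRE = {'got', 'grabbed', 'picked', 'took'}
RELEASE = {'discarded', 'dropped', 'left', 'put'}

def locate_items(triples):
    '''Hypothetical helper: replay (subject, relation, object) triples
    and return a dict mapping each item to the room it is currently in.'''
    person_room = {}  # person -> room they are currently in
    holder = {}       # item -> person currently carrying it
    item_room = {}    # item -> room where it was last released
    for subject, relation, obj in triples:
        if relation in MOVEMENT:
            person_room[subject] = obj
        elif relation in ACQUIRE:
            holder[obj] = subject
            item_room.pop(obj, None)
        elif relation in RELEASE:
            holder.pop(obj, None)
            # A released item stays in the room the person was in
            item_room[obj] = person_room.get(subject)
    # A carried item is wherever its holder currently is
    locations = dict(item_room)
    for item, person in holder.items():
        locations[item] = person_room.get(person)
    return locations

# Made-up story for illustration (not taken from the dataset)
story = [
    ('Mary', 'moved', 'bathroom'),
    ('Mary', 'got', 'football'),
    ('Mary', 'went', 'garden'),
    ('Mary', 'dropped', 'football'),
    ('Mary', 'journeyed', 'kitchen'),
]
print(locate_items(story))  # {'football': 'garden'}
```

The football follows Mary into the garden while she carries it, and stays there after she drops it and moves on to the kitchen.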

The first step is to import `resources/qa2_two-supporting-facts_train.txt` into `data`. Text processing is exactly the same as before: tokenize and POS tag the sentences.
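For orientation, the raw file layout that the cells below rely on can be sketched with the standard library alone: each line starts with a line number, and question lines carry a tab-separated answer and supporting fact ids. The helper `parse_babi_line` and the sample lines are hypothetical illustrations, not part of the notebook or copied from the dataset.

```python
def parse_babi_line(line):
    '''Hypothetical helper: split one raw bAbI line into
    (line_id, sentence, answer, fact_ids).'''
    fields = line.rstrip('\n').split('\t')
    # The first field is "<line number> <sentence>"
    line_id, _, sentence = fields[0].partition(' ')
    # Statements have no answer or fact ids; questions have both
    answer = fields[1] if len(fields) > 1 else ''
    fact_ids = [int(i) for i in fields[2].split()] if len(fields) > 2 else []
    return int(line_id), sentence.strip(), answer, fact_ids

# Made-up lines in the bAbI layout (illustrative only)
sample = [
    '1 Mary moved to the bathroom.',
    '2 Mary got the football there.',
    '3 Where is the football? \tbathroom\t2 1',
]

for raw in sample:
    line_id, sentence, answer, fact_ids = parse_babi_line(raw)
    kind = 'S' if answer == '' else 'Q'  # same statement/question rule as below
    print(line_id, kind, sentence, answer, fact_ids)
```

The notebook itself gets the same split from `pd.read_csv` with a tab delimiter, and drops the line number by discarding the first token after tokenizing.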

In [1]:
import pandas as pd
import numpy as np
import nltk
from sklearn.metrics import accuracy_score

In [2]:
# Read the bAbI data as CSV
filename = 'resources/qa2_two-supporting-facts_train.txt'
data = pd.read_csv(filename, delimiter='\t', names=['sentence', 'answer', 'factid'])
data = data.fillna('')

# Tag each sentence as a statement or question
tag_sentence = lambda row: 'S' if row.answer == '' else 'Q'
data['type'] = data.apply(tag_sentence, axis=1)

# Use NLTK to tokenize the sentences into arrays of words,
# dropping the first token (the bAbI line number)
tokenize = lambda row: nltk.word_tokenize(row.sentence)[1:]
data.sentence = data.apply(tokenize, axis=1)

In [3]:
# Create a DataFrame with just the statements
def statements():
    return data[data.type == 'S'] \
        .reset_index(drop=True) \
        .drop('answer', axis=1) \
        .drop('factid', axis=1) \
        .drop('type', axis=1)

# Create a DataFrame with just the questions
def questions():
    return data[data.type == 'Q'] \
        .reset_index(drop=True) \
        .drop('type', axis=1)

In [4]:
# Tag each token as a part of speech
pos_tag = lambda row: nltk.pos_tag(row.sentence)
data['tag'] = data.apply(pos_tag, axis=1)

In [5]:
def extract_statement(tags):
    '''Extracts a (subject, relation, object) triple from each statement based on the POS tags'''
    subject, relation, obj = '', '', ''
    for word,tag in tags:
        if tag == 'NNP':
            subject = word
        elif tag == 'VBD' or word == 'journeyed': # TODO: 'journeyed' is tagged improperly
            relation = word
        elif tag == 'NN':
            obj = word
    return (subject, relation, obj)

In [6]:
def extract_question(tags):
    '''Extracts the entity under discussion from each question based on the POS tags'''
    eud = ''
    for word,tag in tags:
        if tag == 'NNP' or tag == 'NN':
            eud = word
    return eud

In [7]:
def extract(row):
    '''Extracts the appropriate data given a processed DataFrame row'''
    if row.type == 'S':
        return extract_statement(row.tag)
    else: 
        return extract_question(row.tag)

In [8]:
data['extracted'] = data.apply(extract, axis=1)

## Define the Graph

In [11]:
from neo4j.v1 import GraphDatabase, basic_auth

In [12]:
# Create a neo4j session
# NOTE: Make sure that URL/credentials are correct and that Neo4j is running
driver = GraphDatabase.driver('bolt://localhost:7687', auth=basic_auth('neo4j', 'neo4j'))

In [13]:
# WARNING: This will clear the database when run!
def reset_db():
    '''Remove all nodes and relationships from the database'''
    session = driver.session()
    session.run('MATCH (n) DETACH DELETE n')

In [14]:
def create(query, start=0, end=0):
    '''Create a graph based on each triple in the extracted statements'''
    session = driver.session()
    stat = statements()
    end = len(stat) if end <= start else end
    for subject,relation,obj in stat[start:end].extracted:
        session.run(query, subject=subject, relation=relation, obj=obj)

This is the point where QA2 starts to be different from QA1. The query generating the knowledge graph needs to be altered slightly to encode information about the ordering of events relative to objects as well as subjects.

In QA1, a linked list was constructed to keep track of events relative to a character; the `NEXT` edge type indicated the next event that the _person_ acted upon. This was all that was necessary, since the questions asked directly about the most recent event that corresponded to a particular person.

In QA2, questions ask about the room an item is in, which requires keeping track of the last person who interacted with it. It is therefore not enough to know the order in which a person performed actions; it is also necessary to know the order in which an item was handled. The most recent interaction with an item identifies the person who handled it last, and that person's visit history can then be used to find the room.

Thus, the v4 graph query will create three types of lists.
1. The first list is the global list of events indicated by the `NEXT` edge type.
2. The second list is a `person`'s list of events indicated by the `S_NEXT` (next subject) edge type.
3. The third list is an `item`'s list of events indicated by the `O_NEXT` (next object) edge type.

Each list has a `HEAD` edge that points to the most recent event in that list.
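To make the three lists concrete, here is a minimal in-memory sketch of the same idea. It is not the Cypher query used by the notebook: the `Event` and `EventLists` classes and the `find_room` helper are hypothetical, and the links are stored as backward pointers (the reverse direction of the `NEXT`, `S_NEXT`, and `O_NEXT` edges) because lookups walk from each head toward older events.

```python
MOVEMENT = {'journeyed', 'moved', 'travelled', 'went'}
ACQUIRE = {'got', 'grabbed', 'picked', 'took'}

class Event:
    '''One statement, playing the role of a RELATION node.'''
    def __init__(self, subject, relation, obj):
        self.subject, self.relation, self.obj = subject, relation, obj
        self.prev = None    # previous event globally (reverse of NEXT)
        self.s_prev = None  # previous event of the same subject (reverse of S_NEXT)
        self.o_prev = None  # previous event of the same object (reverse of O_NEXT)

class EventLists:
    '''Three linked lists over the same events, each with a head pointer.'''
    def __init__(self):
        self.global_head = None
        self.subject_head = {}  # subject name -> most recent event
        self.object_head = {}   # object name -> most recent event

    def add(self, subject, relation, obj):
        event = Event(subject, relation, obj)
        # Link the new event to the previous head of each list
        event.prev = self.global_head
        event.s_prev = self.subject_head.get(subject)
        event.o_prev = self.object_head.get(obj)
        # Move all three head pointers to the new event
        self.global_head = event
        self.subject_head[subject] = event
        self.object_head[obj] = event
        return event

def last_room(event):
    '''Walk a subject's list backward to the most recent movement.'''
    while event is not None:
        if event.relation in MOVEMENT:
            return event.obj
        event = event.s_prev
    return None

def find_room(lists, item):
    '''Locate an item from its most recent interaction.'''
    last = lists.object_head.get(item)
    if last is None:
        return None
    if last.relation in ACQUIRE:
        # Still carried: the item is wherever the person is now
        return last_room(lists.subject_head[last.subject])
    # Released: the item stays where the person was at that moment
    return last_room(last)

lists = EventLists()
for triple in [('Mary', 'moved', 'bathroom'), ('Mary', 'got', 'football'),
               ('Mary', 'went', 'garden')]:
    lists.add(*triple)
print(find_room(lists, 'football'))  # garden
```

The object list answers "who touched this item last", and the subject list answers "where was that person", which is the two-step lookup the graph queries below perform.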

In [None]:
v4_query = '''
    /// 1. Create Nodes
    MERGE (global:GLOBAL {name:'global'}) // Find/create the global entity
    MERGE (subject:SUBJECT {name:$subject}) // Find/create the subject and object
    MERGE (object:OBJECT {name:$obj})

    /// 2. Create a new relation between the subject and object
    CREATE (subject)-[:R_BEGIN]->(relation:RELATION {name:$relation})-[:R_END]->(object)

    /// 3. Create head pointers to the newly created relation
    CREATE (global)-[globalHead:HEAD]->(relation)
    CREATE (subject)-[subjectHead:HEAD]->(relation)
    CREATE (object)-[objectHead:HEAD]->(relation)

    WITH global,subject,relation,object,subjectHead,objectHead,globalHead

    /// 4. Link the existing global list with the new head node
    // Find the previous global head of the list (if none exist, this query will terminate here)
    MATCH (global)-[prevGlobalHead:HEAD]->(prevGlobalRelation:RELATION) WHERE prevGlobalRelation <> relation
    CREATE (prevGlobalRelation)-[:NEXT]->(relation) // Complete the link
    DELETE prevGlobalHead // Remove the previous head pointer

    WITH subject,relation,object,subjectHead,objectHead

    /// 5. Link the existing subject list with the new head node
    // Find the previous subject head of the list (if none exist, this query will terminate here)
    MATCH (subject)-[prevSubjectHead:HEAD]->(prevSubjectRelation:RELATION) WHERE prevSubjectRelation <> relation
    CREATE (prevSubjectRelation)-[:S_NEXT]->(relation) // Complete the link
    DELETE prevSubjectHead // Remove the previous head pointer

    WITH subject,relation,object,objectHead

    /// 6. Link the existing object list with the new head node
    // Find the previous object head of the list (if none exist, this query will terminate here)
    MATCH (object)-[prevObjectHead:HEAD]->(prevObjectRelation:RELATION) WHERE prevObjectRelation <> relation
    CREATE (prevObjectRelation)-[:O_NEXT]->(relation) // Complete the link
    DELETE prevObjectHead // Remove the previous head pointer
'''

In [15]:
# Represent each relation as a node, ordered by multiple linked lists
def build_v4_graph(start=0, end=0):
    reset_db()
    
    session = driver.session()
    
    # Create an index for faster access
    session.run('CREATE INDEX ON :SUBJECT(name)')
    session.run('CREATE INDEX ON :RELATION(name)')
    session.run('CREATE INDEX ON :OBJECT(name)')
    
    create(v4_query, start=start, end=end)

In [16]:
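# Collect the distinct relation verbs from statement triples only
# (question rows hold a plain string, where x[1] would be a single character)
all_actions = sorted({x[1] for x in data.extracted if isinstance(x, tuple) and x[1] != ''})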

In [17]:
movement_actions = ['journeyed', 'moved', 'travelled', 'went']
acquire_actions = ['got', 'grabbed', 'picked', 'took']
release_actions = ['discarded', 'dropped', 'left', 'put']

In [18]:
def find_last_person(obj):
    '''Finds the last person in contact with the object'''
    query = '''
        MATCH (:OBJECT {name:$name})-[:HEAD]->(relation:RELATION)<-[:R_BEGIN]-(subject:SUBJECT)
        RETURN relation.name AS relation, subject.name AS subject
    '''
    return session.run(query, name=obj)

In [20]:
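def find_object_location(obj):
    '''Finds two candidate rooms for the object: `acquire` is the room of the
    last handler's most recent movement, and `release` is the room of their
    last movement before the final interaction with the object'''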
    query = '''
        // Find the last person in contact with the object
        MATCH (:OBJECT {name:$obj})-[:HEAD]->(relation:RELATION)<-[:R_BEGIN]-(subject:SUBJECT)

        // Acquire
        MATCH (subject)-[:HEAD]->(head_relation:RELATION)
        
        MATCH p=(head_relation)<-[next:S_NEXT *1..20]-(prevRelation:RELATION)
        WHERE prevRelation.name IN $movement
        WITH size(next) as dist, p, relation
        ORDER BY dist
        WITH filter(n IN nodes(p) WHERE n.name IN $movement)[0] AS shortest, relation
        MATCH (shortest)-[:R_END]->(object_acquire:OBJECT)
        
        WITH relation, object_acquire

        // Release
        MATCH p=(relation)<-[next:S_NEXT *1..20]-(prevRelation:RELATION)
        WHERE prevRelation.name IN $movement
        WITH size(next) as dist, p, object_acquire, relation
        ORDER BY dist
        WITH filter(n IN nodes(p) WHERE n.name IN $movement)[0] AS shortest, object_acquire, relation
        MATCH (shortest)-[:R_END]->(object_release:OBJECT)

        RETURN DISTINCT object_acquire.name AS acquire, object_release.name AS release, relation.name AS relation
    '''
    return session.run(query, obj=obj, movement=movement_actions)

In [21]:
# Alternative version of find_object_location (98.6% accuracy), kept commented out for reference

# def find_object_location(obj):
#     query = '''
#         // Find the last person in contact with the object
#         MATCH (:OBJECT {name:$obj})-[:HEAD]->(relation:RELATION)<-[:R_BEGIN]-(subject:SUBJECT)

#         // Acquire
#         MATCH (subject)-[:HEAD]->(:RELATION)-[:R_END]->(object_acquire:OBJECT)

#         // Release
#         MATCH p=(relation)<-[next:S_NEXT *1..20]-(prevRelation:RELATION)
#         WHERE prevRelation.name IN $movement
#         WITH size(next) as dist, p, object_acquire, relation
#         ORDER BY dist
#         WITH filter(n IN nodes(p) WHERE n.name IN $movement)[0] AS shortest, object_acquire, relation
#         MATCH (shortest)-[:R_END]->(object_release:OBJECT)

#         RETURN DISTINCT object_acquire.name AS acquire, object_release.name AS release, relation.name AS relation
#     '''
#     return session.run(query, obj=obj, movement=movement_actions)

In [24]:
build_v4_graph(start=0, end=6)

session = driver.session()
find_object_location('football').single()

<Record acquire='garden' release='bathroom' relation='got'>

## Build the Graph

In [26]:
build_v4_graph()

<img src="screenshots/qa2-multiple-list.png" style="width:1000px;">

## Calculate an Accuracy Score

In [27]:
def get_answers(row):
    '''Given an input row merge the statement in the graph, 
    or query the graph if it is a question'''
    if row.type == 'S':
        subject,relation,obj = row.extracted
        session.run(v4_query, subject=subject, relation=relation, obj=obj)
        return ''
    elif row.type == 'Q':
        obj = row.extracted
        # WARNING: do not consume the result (e.g., call .consume() or .single()) 
        # until the entire iteration is done.
        # Failure to do so may cause the queries to be VERY slow!
        return find_object_location(obj)

In [28]:
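def traverse(result):
    '''Pick the room from a query record: if the last interaction was an
    acquire, the item is with the person; otherwise it is where it was released'''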
    if result['relation'] in acquire_actions:
        return result['acquire']
    else:
        return result['release']

In [30]:
reset_db()

In [31]:
session = driver.session()
results = data.apply(get_answers, axis=1)
results = [x for x in results if x != '']
predicted = [traverse(result.single()) for result in results]

In [32]:
predicted[:5]

['garden', 'garden', 'hallway', 'hallway', 'kitchen']

In [33]:
actual = list(questions().answer)

In [34]:
actual[:5]

['garden', 'garden', 'hallway', 'hallway', 'kitchen']

In [35]:
accuracy_score(actual, predicted)

1.0

In [36]:
def find_incorrect(actual, predicted):
    '''Returns (index, actual, predicted) for every question answered incorrectly'''
    z = list(zip(actual, predicted))
    return [(i, x[0], x[1]) for i,x in enumerate(z) if x[0] != x[1]]