# Memory Representation in Dialogue Systems (Part 2)

Under construction, will update with explanations when finished.

## Process the Text

In [102]:
import pandas as pd
import numpy as np
import nltk
from sklearn.metrics import accuracy_score

In [103]:
# Read the bAbI data as CSV
filename = 'resources/qa2_two-supporting-facts_train.txt'
data = pd.read_csv(filename, delimiter='\t', names=['sentence', 'answer', 'factid'])
data = data.fillna('')

# Tag each sentence as a statement or question
tag_sentence = lambda row: 'S' if row.answer == '' else 'Q'
data['type'] = data.apply(tag_sentence, axis=1)

# Use NLTK to tokenize the sentences into arrays of words
tokenize = lambda row: nltk.word_tokenize(row.sentence)[1:]
data.sentence = data.apply(tokenize, axis=1)

In [104]:
# Create a DataFrame with just the statements
def statements():
    return data[data.type == 'S'] \
        .reset_index(drop=True) \
        .drop('answer', axis=1) \
        .drop('factid', axis=1) \
        .drop('type', axis=1)

# Create a DataFrame with just the questions
def questions():
    return data[data.type == 'Q'] \
        .reset_index(drop=True) \
        .drop('type', axis=1)

In [105]:
# Tag each token as a part of speech
pos_tag = lambda row: nltk.pos_tag(row.sentence)
data['tag'] = data.apply(pos_tag, axis=1)

In [106]:
def extract_statement(tags):
    '''Extracts a (subject, relation, object) triple from each statement based on the POS tags'''
    subject, relation, obj = '', '', ''
    for word,tag in tags:
        if tag == 'NNP':
            subject = word
        elif tag == 'VBD':
            relation = word
        elif tag == 'NN':
            obj = word
    return (subject, relation, obj)

In [178]:
def extract_question(tags):
    '''Extracts the entity under discussion from each question based on the POS tags'''
    eud = ''
    for word,tag in tags:
        if tag == 'NNP' or tag == 'NN':
            eud = word
    return eud

In [179]:
def extract(row):
    '''Extracts the appropriate data given a processed DataFrame row'''
    if row.type == 'S':
        return extract_statement(row.tag)
    else: 
        return extract_question(row.tag)

In [180]:
data['extracted'] = data.apply(extract, axis=1)

In [110]:
def person_statements(person):
    '''Get all statements that refer to the specified person'''
    stat = statements()
    return stat[stat.extracted.map(lambda t: t[0] == person)]

In [111]:
def person_statements_recent(person, n=5):
    '''Get the n most recent statements that refer to the specified person in reverse chronological order'''
    return person_statements(person)[-n:].iloc[::-1]

## Define the Graph

In [302]:
from neo4j.v1 import GraphDatabase, basic_auth

In [303]:
# Create a neo4j session
# NOTE: Make sure that URL/credentials are correct and that Neo4j is running
driver = GraphDatabase.driver('bolt://localhost:7687', auth=basic_auth('neo4j', 'neo4j'))

In [304]:
# WARNING: This will clear the database when run!
def reset_db():
    '''Remove all nodes and relationships from the database'''
    session = driver.session()
    session.run('MATCH (n) DETACH DELETE n')

In [305]:
def create(query, start=0, end=0):
    '''Create a graph based on each triple in the extracted statements'''
    session = driver.session()
    stat = statements()
    end = len(stat) if end <= start else end
    for subject,relation,obj in stat[start:end].extracted:
        session.run(query, subject=subject, relation=relation, obj=obj)

In [345]:
v4_query = '''
    /// 1. Create Nodes
    MERGE (global:GLOBAL {name:'global'}) // Find/create the global entity
    MERGE (subject:SUBJECT {name:$subject}) // Find/create the subject and object
    MERGE (object:OBJECT {name:$obj})

    /// 2. Create a new relation between the subject and object
    CREATE (subject)-[:R_BEGIN]->(relation:RELATION {name:$relation})-[:R_END]->(object)

    /// 3. Create head pointers to the newly created relation
    CREATE (global)-[globalHead:HEAD]->(relation)
    CREATE (subject)-[subjectHead:HEAD]->(relation)
    CREATE (object)-[objectHead:HEAD]->(relation)

    WITH global,subject,relation,object,subjectHead,objectHead,globalHead

    /// 4. Link the existing global list with the new head node
    // Find the previous global head of the list (if none exist, this query will terminate here)
    MATCH (global)-[prevGlobalHead:HEAD]->(prevGlobalRelation:RELATION) WHERE prevGlobalRelation <> relation
    CREATE (prevGlobalRelation)-[:NEXT]->(relation) // Complete the link
    DELETE prevGlobalHead // Remove the previous head pointer

    WITH subject,relation,object,subjectHead,objectHead

    /// 5. Link the existing subject list with the new head node
    // Find the previous subject head of the list (if none exist, this query will terminate here)
    MATCH (subject)-[prevSubjectHead:HEAD]->(prevSubjectRelation:RELATION) WHERE prevSubjectRelation <> relation
    CREATE (prevSubjectRelation)-[:S_NEXT]->(relation) // Complete the link
    DELETE prevSubjectHead // Remove the previous head pointer

    WITH subject,relation,object,objectHead

    /// 6. Link the existing object list with the new head node
    // Find the previous subject head of the list (if none exist, this query will terminate here)
    MATCH (object)-[prevObjectHead:HEAD]->(prevObjectRelation:RELATION) WHERE prevObjectRelation <> relation
    CREATE (prevObjectRelation)-[:O_NEXT]->(relation) // Complete the link
    DELETE prevObjectHead // Remove the previous head pointer
'''

# Represent each relation as a node, ordered by multiple linked lists
def build_v4_graph(start=0, end=0):
    reset_db()
    
    session = driver.session()
    # Create an index for faster access
    session.run('CREATE INDEX ON :SUBJECT(name)')
    session.run('CREATE INDEX ON :RELATION(name)')
    session.run('CREATE INDEX ON :OBJECT(name)')
    create(v4_query, start=start, end=end)

build_v4_graph(start=0, end=12)

In [307]:
statements()[0:16]

Unnamed: 0,sentence,tag,extracted
0,"[Mary, moved, to, the, bathroom, .]","[(Mary, NNP), (moved, VBD), (to, TO), (the, DT...","(Mary, moved, bathroom)"
1,"[Sandra, journeyed, to, the, bedroom, .]","[(Sandra, NNP), (journeyed, VBD), (to, TO), (t...","(Sandra, journeyed, bedroom)"
2,"[Mary, got, the, football, there, .]","[(Mary, NNP), (got, VBD), (the, DT), (football...","(Mary, got, football)"
3,"[John, went, to, the, kitchen, .]","[(John, NNP), (went, VBD), (to, TO), (the, DT)...","(John, went, kitchen)"
4,"[Mary, went, back, to, the, kitchen, .]","[(Mary, NNP), (went, VBD), (back, RB), (to, TO...","(Mary, went, kitchen)"
5,"[Mary, went, back, to, the, garden, .]","[(Mary, NNP), (went, VBD), (back, RB), (to, TO...","(Mary, went, garden)"
6,"[Sandra, went, back, to, the, office, .]","[(Sandra, NNP), (went, VBD), (back, RB), (to, ...","(Sandra, went, office)"
7,"[John, moved, to, the, office, .]","[(John, NNP), (moved, VBD), (to, TO), (the, DT...","(John, moved, office)"
8,"[Sandra, journeyed, to, the, hallway, .]","[(Sandra, NNP), (journeyed, VBD), (to, TO), (t...","(Sandra, journeyed, hallway)"
9,"[Daniel, went, back, to, the, kitchen, .]","[(Daniel, NNP), (went, VBD), (back, RB), (to, ...","(Daniel, went, kitchen)"


In [308]:
all_actions = sorted(list(set(x[1] for x in data.extracted if x != '' and x[1] != '')))

In [309]:
all_actions

['discarded',
 'dropped',
 'got',
 'grabbed',
 'i',
 'journeyed',
 'left',
 'moved',
 'o',
 'p',
 'picked',
 'put',
 'took',
 'travelled',
 'went']

In [310]:
movement_actions = ['journeyed', 'moved', 'travelled', 'went']
acquire_actions = ['got', 'grabbed', 'picked', 'took']
release_actions = ['discarded', 'dropped', 'left', 'put']

In [311]:
def find_last_person(obj):
    '''Finds the last person in contact with the object'''
    query = '''
        MATCH (:OBJECT {name:$name})-[:HEAD]->(relation:RELATION)<-[:R_BEGIN]-(subject:SUBJECT)
        RETURN relation.name AS relation, subject.name AS subject
    '''
    return session.run(query, name=obj)

In [340]:
session = driver.session()
last_person = find_last_person('football').single()
last_person

In [368]:
def find_object_location(subject, relation):
    if relation in acquire_actions:
        query = '''
            MATCH (:SUBJECT {name:$subject})-[:HEAD]->(:RELATION)-[:R_END]->(object:OBJECT)
            RETURN object.name AS obj
            
            //MATCH (:SUBJECT {name:$subject})-[:HEAD]->(relation:RELATION)
            //MATCH p=shortestPath((relation)-[:S_NEXT *]->(nextRelation:RELATION))
            //WHERE nextRelation.name IN $movement
            //MATCH (nextRelation)-[:R_END]->(object:OBJECT)
            //RETURN object.name AS obj
        '''
    else:
        query = '''
            //MATCH (:SUBJECT {name:$subject})-[:R_BEGIN]->(:RELATION)<-[:S_NEXT]-(:RELATION)-[:R_END]->(object:OBJECT)
            //RETURN object.name AS obj
            
            MATCH (:SUBJECT {name:$subject})-[:R_BEGIN]->(relation:RELATION {name:$relation})
            MATCH p=(relation)<-[next:S_NEXT *1..20]-(prevRelation:RELATION)
            WHERE prevRelation.name IN $movement
            WITH size(next) as dist, p
            ORDER BY dist
            WITH filter(n IN nodes(p) WHERE n.name IN $movement)[0] AS shortest
            MATCH (shortest)-[:R_END]->(object:OBJECT)
            RETURN DISTINCT object.name AS obj
        '''
    return session.run(query, subject=subject, relation=relation, movement=movement_actions)

In [369]:
# def find_object_location(obj):
#     if relation in acquire_actions:
#         query = '''
#             MATCH (:OBJECT {name:$obj})-[:HEAD]->(relation:RELATION)<-[:R_BEGIN]-(subject:SUBJECT)
        
#             MATCH (subject)-[:HEAD]->(:RELATION)-[:R_END]->(object:OBJECT)
#             RETURN object.name AS obj
            
#             //MATCH (:SUBJECT {name:$subject})-[:HEAD]->(relation:RELATION)
#             //MATCH p=shortestPath((relation)-[:S_NEXT *]->(nextRelation:RELATION))
#             //WHERE nextRelation.name IN $movement
#             //MATCH (nextRelation)-[:R_END]->(object:OBJECT)
#             //RETURN object.name AS obj
#         '''
#     else:
#         query = '''
#             //MATCH (:SUBJECT {name:$subject})-[:R_BEGIN]->(:RELATION)<-[:S_NEXT]-(:RELATION)-[:R_END]->(object:OBJECT)
#             //RETURN object.name AS obj
            
#             MATCH (:OBJECT {name:$obj})-[:HEAD]->(relation:RELATION)<-[:R_BEGIN]-(subject:SUBJECT)
            
#             MATCH p=(relation)<-[next:S_NEXT *1..20]-(prevRelation:RELATION)
#             WHERE prevRelation.name IN $movement
#             WITH size(next) as dist, p
#             ORDER BY dist
#             WITH filter(n IN nodes(p) WHERE n.name IN $movement)[0] AS shortest
#             MATCH (shortest)-[:R_END]->(object:OBJECT)
#             RETURN DISTINCT object
#         '''
#     return session.run(query, obj=obj, movement=movement_actions)

In [370]:
session = driver.session()
find_object_location('Mary', 'dropped').single()

  warn("Expected a result with a single record, but this result contains %d" % size)


<Record obj='kitchen'>

## Build the Graph

In [63]:
build_v4_graph()

<img src="screenshots/global-and-local-list.png",width=1000>

## Calcualte an Accuracy Score

In [372]:
def get_answers(row):
    '''Given an input row merge the statement in the graph, 
    or query the graph if it is a question'''
    if row.type == 'S':
        subject,relation,obj = row.extracted
        session.run(v4_query, subject=subject, relation=relation, obj=obj)
        return ''
    elif row.type == 'Q':
        obj = row.extracted
        # WARNING: do not consume the result (e.g., call .consume() or .single()) 
        # until the entire iteration is done.
        # Failure to do so may cause the queries to be VERY slow!
        return find_last_person(obj)

In [373]:
reset_db()

In [374]:
session = driver.session()
results = data.apply(get_answers, axis=1)
results = [x.single() for x in results if x != '']
results = [find_object_location(x['subject'], x['relation']) for x in results]
predicted = [x.single()['obj'] for x in results]

  warn("Expected a result with a single record, but this result contains %d" % size)


In [375]:
actual = list(data[data.type == 'Q'].answer)

In [376]:
accuracy_score(actual, predicted)

0.155