# Memory Representation in Dialogue Systems (Part 2)

Under construction, will update with explanations when finished.

## Process the Text

In [4]:
import pandas as pd
import numpy as np
import nltk
from sklearn.metrics import accuracy_score

In [5]:
# Read the bAbI data as CSV
filename = 'resources/qa2_two-supporting-facts_train.txt'
data = pd.read_csv(filename, delimiter='\t', names=['sentence', 'answer', 'factid'])
data = data.fillna('')

# Tag each sentence as a statement or question
tag_sentence = lambda row: 'S' if row.answer == '' else 'Q'
data['type'] = data.apply(tag_sentence, axis=1)

# Use NLTK to tokenize the sentences into arrays of words
tokenize = lambda row: nltk.word_tokenize(row.sentence)[1:]
data.sentence = data.apply(tokenize, axis=1)

In [6]:
# Create a DataFrame with just the statements
def statements():
    return data[data.type == 'S'] \
        .reset_index(drop=True) \
        .drop('answer', axis=1) \
        .drop('factid', axis=1) \
        .drop('type', axis=1)

# Create a DataFrame with just the questions
def questions():
    return data[data.type == 'Q'] \
        .reset_index(drop=True) \
        .drop('type', axis=1)

In [7]:
# Tag each token as a part of speech
pos_tag = lambda row: nltk.pos_tag(row.sentence)
data['tag'] = data.apply(pos_tag, axis=1)

In [9]:
def extract_statement(tags):
    '''Extracts a (subject, relation, object) triple from each statement based on the POS tags'''
    subject, relation, obj = '', '', ''
    for word,tag in tags:
        if tag == 'NNP':
            subject = word
        elif tag == 'VBD':
            relation = word
        elif tag == 'NN':
            obj = word
    return (subject, relation, obj)

In [10]:
def extract_question(tags):
    '''Extracts the entity under discussion from each question based on the POS tags'''
    eud = ''
    for word,tag in tags:
        if tag == 'NNP':
            eud = word
    return eud

In [11]:
def extract(row):
    '''Extracts the appropriate data given a processed DataFrame row'''
    if row.type == 'S':
        return extract_statement(row.tag)
    else: 
        return extract_question(row.tag)

In [12]:
data['extracted'] = data.apply(extract, axis=1)

In [13]:
def person_statements(person):
    '''Get all statements that refer to the specified person'''
    stat = statements()
    return stat[stat.extracted.map(lambda t: t[0] == person)]

In [30]:
def person_statements_recent(person, n=5):
    '''Get the n most recent statements that refer to the specified person in reverse chronological order'''
    return person_statements(person)[-n:].iloc[::-1]

## Define the Graph

In [15]:
from neo4j.v1 import GraphDatabase, basic_auth

In [16]:
# Create a neo4j session
# NOTE: Make sure that URL/credentials are correct and that Neo4j is running
driver = GraphDatabase.driver('bolt://localhost:7687', auth=basic_auth('neo4j', 'neo4j'))

In [17]:
# WARNING: This will clear the database when run!
def reset_db():
    '''Remove all nodes and relationships from the database'''
    session = driver.session()
    session.run('MATCH (n) DETACH DELETE n')

In [91]:
def create(query, n=0):
    '''Create a graph based on each triple in the extracted statements'''
    session = driver.session()
    stat = statements()
    n = len(stat) if n < 1 else n
    for subject,relation,obj in stat[:n].extracted:
        session.run(query, { 
            'subject': subject,
            'relation': relation,
            'obj': obj
        })

In [92]:
# Represent each relation as a node, ordered by a linked list (per subject)
def build_v4_graph(n=0):
    reset_db()
    
    v4_query = '''
        /// 1. Create Nodes
        MERGE (global:GLOBAL {name:'global'}) // Find/create the global entity
        MERGE (subject:SUBJECT {name:$subject}) // Find/create the subject and object
        MERGE (object:OBJECT {name:$obj})

        /// 2. Create a new relation between the subject and object
        CREATE (subject)-[:R_BEGIN]->(relation:RELATION {name:$relation})-[:R_END]->(object)
        
        /// 3. Create head pointers to the newly created relation
        CREATE (global)-[globalHead:HEAD]->(relation)
        CREATE (subject)-[subjectHead:HEAD]->(relation)

        WITH global,subject,relation,object,subjectHead,globalHead
        
        /// 4. Link the existing global list with the new head node
        // Find the previous global head of the list (if none exist, this query will terminate here)
        MATCH (global)-[prevGlobalHead:HEAD]->(prevGlobalRelation:RELATION) WHERE prevGlobalRelation <> relation
        CREATE (prevGlobalRelation)-[:NEXT]->(relation) // Complete the link
        DELETE prevGlobalHead // Remove the previous head pointer
        
        WITH subject,relation,object,subjectHead

        /// 5. Link the existing subject list with the new head node
        // Find the previous subject head of the list (if none exist, this query will terminate here)
        MATCH (subject)-[prevSubjectHead:HEAD]->(prevSubjectRelation:RELATION) WHERE prevSubjectRelation <> relation
        CREATE (prevSubjectRelation)-[:S_NEXT]->(relation) // Complete the link
        DELETE prevSubjectHead // Remove the previous head pointer
    '''
    
    session = driver.session()
    # Create an index for faster access
    session.run('CREATE INDEX ON :SUBJECT(name)')
    session.run('CREATE INDEX ON :RELATION(name)')
    session.run('CREATE INDEX ON :OBJECT(name)')
    create(v4_query, n=n)

build_v4_graph(n=10)

In [44]:
def find_person(person):
    '''Find the room a person is currently in'''
    query = '''
        MATCH (s:SUBJECT {name:$name})-[:HEAD]->(r:RELATION)-->(o:OBJECT)
        RETURN s AS subject, r AS relation, o AS obj
    '''
    return session.run(query, {'name': person})

In [45]:
def find_person_history(person, n=100):
    '''Find the list of rooms a person was in, ordered by recency'''
    length = str(n) if n >= 1 else ''
    
    query = '''
        MATCH (s:SUBJECT {name:$name})-[:HEAD]->(r:RELATION)-->(o:OBJECT)
        MATCH (s)-->(r_prev:RELATION)-[k*1..%s]->(r), (r_prev)-->(o_prev:OBJECT)
        
        WITH size(k) AS dist, r, o, r_prev, o_prev
        ORDER BY size(k)
        
        WITH r, o, r_prev, o_prev
        RETURN [r.name] + collect(r_prev.name) AS relation, [o.name] + collect(o_prev.name) AS obj
    '''
    query = query % length
    
    session = driver.session()
    record = session.run(query, {'name': person}).single()
    history = list(zip(record['relation'], record['obj']))[:-1]
    
    return history

## Build the Graph

In [63]:
build_v4_graph()

<img src="screenshots/global-and-local-list.png",width=1000>

## Calcualte an Accuracy Score

In [25]:
def get_answers(row):
    '''Given an input row merge the statement in the graph, 
    or query the graph if it is a question'''
    if row.type == 'S':
        subject,relation,obj = row.extracted
        session.run(v3_query, {
            'subject': subject,
            'relation': relation,
            'obj': obj
        })
        return ''
    elif row.type == 'Q':
        person = row.extracted
        # WARNING: do not consume the result (e.g., call .consume() or .single()) 
        # until the entire iteration is done.
        # Failure to do so may cause the queries to be VERY slow!
        return find_person(person)

In [26]:
reset_db()

In [27]:
session = driver.session()
results = data.apply(get_answers, axis=1)
results = [x for x in results if x != '']
predicted = [result.single()['obj'].get('name') for result in results]

TypeError: 'NoneType' object is not subscriptable

In [28]:
actual = list(data[data.type == 'Q'].answer)

In [29]:
accuracy_score(actual, predicted)

NameError: name 'predicted' is not defined