In [3]:
import spacy
import logging
import re
import pandas as pd
from nltk.metrics import edit_distance

# Root logger: every named logger created below (e.g. "nearest_noun_chunks",
# "TemplateExtractor") propagates to it unless configured otherwise.
logger = logging.getLogger()
# logging.WARN is an obsolete alias; WARNING is the documented name for the same level.
logger.setLevel(logging.WARNING)

# Load the spaCy pipeline 'en_core_web_lg' once at notebook level;
# TemplateExtractor.extract_template calls `nlp(text)` to obtain the parsed doc.
nlp = spacy.load('en_core_web_lg')

In [54]:
def replace_by_groupname(m):
    """Substitution callback for `re.sub`: turn a match into a `{group_name}` placeholder.

    Parameters
    ----------
    m : re.Match
        Match produced by a pattern with named groups.

    Returns
    -------
    str
        The name of the first named group (in `groupdict()` order) whose
        captured value is truthy, wrapped in single curly braces.

    Raises
    ------
    StopIteration
        If no named group captured a non-empty value.
    """
    matched_names = (name for name, captured in m.groupdict().items() if captured)
    first_name = next(matched_names)

    return "{{{}}}".format(first_name)

# Default distance above which a (mention, noun chunk) pairing is logged as suspicious.
DISTANCE_THRESHOLD = 200

nnc_logger = logging.getLogger("nearest_noun_chunks")
def nearest_noun_chunks(doc, m_subject, m_object, distance=None, threshold=None):
    """Pick, for the subject and the object mention, the closest noun chunk of `doc`.

    The assignment is greedy: the globally best (mention, chunk) pair is fixed
    first, then the other mention gets the best chunk among the remaining ones.

    Parameters
    ----------
    doc
        Object exposing `noun_chunks`, an iterable of spans that have a `.text`
        attribute and support `!=` (the parsed doc returned by `nlp(text)`).
    m_subject, m_object : str
        Surface strings of the triple's subject and object.
    distance : callable, optional
        `distance(mention, chunk_text) -> number`. Defaults to
        `nltk.metrics.edit_distance`, looked up at call time.
    threshold : number, optional
        Distances strictly above it trigger a warning. Defaults to the current
        value of the module-level `DISTANCE_THRESHOLD`.

    Returns
    -------
    dict
        `{'m_subject': chunk, 'm_object': chunk}` with two different chunks.

    Raises
    ------
    ValueError
        If `doc` has no noun chunks, or not enough of them to give the subject
        and the object different chunks. (`ValueError` is an `Exception`, so
        existing `except Exception` handlers keep working.)
    """
    if distance is None:
        distance = edit_distance
    if threshold is None:
        threshold = DISTANCE_THRESHOLD

    # Materialize once: the chunks are scanned for both mentions.
    noun_chunks = list(doc.noun_chunks)

    # Candidates are tagged with the role, not the mention text. Keying by text
    # made the subject and object indistinguishable whenever both strings were
    # equal, which always ended in the "insufficient noun chunks" error.
    mentions = (('m_subject', m_subject), ('m_object', m_object))
    distances = [(role, nc, distance(mention, nc.text))
                 for role, mention in mentions
                 for nc in noun_chunks]

    if not distances:

        raise ValueError("doc without sufficient noun chunks: {}".format(doc))

    nnc_logger.debug(distances)

    # On ties `min` keeps the first candidate, i.e. subject candidates win.
    min_distance_1 = min(distances, key=lambda v: v[2])

    if min_distance_1[2] > threshold:
        nnc_logger.warning("distance threshold: {}".format(min_distance_1))

    # Drop every candidate that reuses the chosen chunk or the chosen role.
    distances_without_m = [v for v in distances
                           if v[1] != min_distance_1[1] and v[0] != min_distance_1[0]]

    if not distances_without_m:

        raise ValueError("doc without sufficient noun chunks: {}".format(doc))

    nnc_logger.debug(distances_without_m)

    min_distance_2 = min(distances_without_m, key=lambda v: v[2])

    if min_distance_2[2] > threshold:
        nnc_logger.warning("distance threshold: {}".format(min_distance_2))

    return {min_distance_1[0]: min_distance_1[1],
            min_distance_2[0]: min_distance_2[1]}


class TemplateExtractor(object):
    """Turns a sentence plus its triple into a `Template` with `{m_subject}` / `{m_object}` slots."""

    # Fixed slot order: on equal-length chunk texts the subject alternative is tried first.
    _SLOT_NAMES = ('m_subject', 'm_object')

    def __init__(self):

        self.logger = logging.getLogger('TemplateExtractor')

    @staticmethod
    def _escape_braces(s):
        """Double literal braces so `str.format` later reads them as plain characters."""
        return s.replace('{', '{{').replace('}', '}}')

    def extract_template(self, text, triple):
        """Replace the noun chunks closest to the triple's subject and object by slots.

        Parameters
        ----------
        text : str
            Sentence verbalizing the triple.
        triple : dict
            Must contain the keys 'm_subject' and 'm_object'.

        Returns
        -------
        Template
            Built from the slotted string, the original `text` and `triple`.

        Raises
        ------
        Exception
            Whatever `nearest_noun_chunks` raises when it cannot pair both
            mentions with distinct noun chunks.
        """
        doc = nlp(text)

        ncc = nearest_noun_chunks(doc, triple['m_subject'], triple['m_object'])

        self.logger.debug(ncc)

        # Longest chunk text first: with alternation the first alternative that
        # matches wins, so a chunk that is a prefix of the other one would
        # otherwise steal part of the longer chunk. The sort is stable, so
        # equal lengths keep the subject-first order.
        slots = sorted(((name, ncc[name].text) for name in self._SLOT_NAMES),
                       key=lambda slot: len(slot[1]),
                       reverse=True)

        # The template is later filled with str.format (Template.fill), so any
        # literal brace of the sentence must be doubled. The chunk texts get
        # the same treatment so that they still match inside the escaped sentence.
        alternatives = "|".join(
            "(?P<{}>{})".format(name, re.escape(self._escape_braces(chunk_text)))
            for name, chunk_text in slots)

        # Compiling is optional (the `re` module caches patterns); kept for clarity.
        c = re.compile("({})".format(alternatives))

        return Template(c.sub(replace_by_groupname, self._escape_braces(doc.text)), text, triple)
    

#TODO: search python template libraries
class Template:
    """A `str.format`-style template together with the example it was extracted from.

    Attributes are plain instance fields:
      template_string -- the text containing `{name}` replacement fields
      text            -- the original sentence the template was derived from
      triple          -- the triple that was passed along with that sentence

    No `__eq__` / `__hash__` is defined, so instances compare and hash by
    identity: two templates with the same `template_string` are distinct
    members of a set.
    """

    def __init__(self, template_string, text, triple):
        """Store the template string and the (text, triple) pair it came from."""
        self.text = text
        self.triple = triple
        self.template_string = template_string

    def __str__(self):
        """Return the raw template string, placeholders included."""
        return self.template_string

    def fill(self, triple):
        """Format the template string with the entries of `triple` as keyword arguments.

        Keys of `triple` that do not appear as fields in the template are ignored
        by `str.format`; a field missing from `triple` raises `KeyError`.
        """
        filled = self.template_string.format(**triple)
        return filled

In [30]:
# Smoke test on a hand-written sentence: extract a template from one
# (text, triple) pair. `te` and `t` are notebook-level names reused by later cells.
te = TemplateExtractor()

# Only 'm_subject' and 'm_object' are read by extract_template; 'm_predicate'
# is merely carried along inside the triple stored on the Template.
t = te.extract_template("Eleanor Rigby picks up the rice in the church",
                        {'m_subject': 'Eleanor Rigby', 
                         'm_predicate': 'pick up', 
                         'm_object': 'church'})

In [31]:
# Fill the extracted template with a different triple. Only the subject and
# object slots exist in the template, so 'm_predicate' has no effect (the verb
# stays "picks up"). The whole noun chunk "the church" became the object slot,
# which is why the determiner is lost in the result shown below.
t.fill({'m_subject': 'Abelardo',
        'm_predicate': 'drive',
        'm_object': 'car'})

'Abelardo picks up the rice in car'

# Testing

In [38]:
# Execute the helper script inside this notebook's namespace; WebNLGCorpus is
# not defined in this notebook, so it is expected to come from that script.
%run ../script/webnlg.py

train = WebNLGCorpus.load('train')
# presumably keeps only the entries made of a single triple -- confirm in webnlg.py
train_1 = train.subset(ntriples=1)

In [55]:
%%time
from collections import defaultdict

logger.setLevel(logging.ERROR)
template_db = defaultdict(set)

lexes_triples = pd.merge(train_1.ldf, train_1.mdf)

te = TemplateExtractor()

for ix, row in lexes_triples.iterrows():
    lexe = row['ltext']
    triple = {'m_subject': row['m_subject'],
              'm_object': row['m_object'],
              'm_predicate': row['m_predicate']
             }
    try:
        t = te.extract_template(lexe, triple)
    
        template_db[row['m_predicate']].add(t)
    except Exception as ex:
        
        pass

CPU times: user 2min 17s, sys: 1.59 s, total: 2min 19s
Wall time: 37.2 s


In [40]:
# Number of predicates for which at least one template was extracted
# (keys are only created when a template is added).
len(template_db)

225

In [41]:
# Number of distinct predicates in train_1, to compare against len(template_db):
# the difference is the set of predicates for which every extraction failed.
train_1.mdf.m_predicate.unique().shape

(227,)

In [42]:
# (predicate, number of extracted templates) pairs, in template_db order
data_index = [(predicate, len(template_db[predicate])) for predicate in template_db]
data = [n_templates for _, n_templates in data_index]
index = [predicate for predicate, _ in data_index]

# Templates-per-predicate counts, indexed by predicate name
stats_on_templates = pd.Series(data, index=index)
stats_on_templates.describe()

count    225.000000
mean      18.613333
std       36.900648
min        1.000000
25%        3.000000
50%        6.000000
75%       18.000000
max      290.000000
dtype: float64

In [43]:
# The ten predicates with the most extracted templates.
stats_on_templates.nlargest(10)

country         290
isPartOf        235
leaderName      219
location        193
club            168
ingredient      131
language        124
runwayLength    106
creator          94
ethnicGroup      74
dtype: int64

In [57]:
from itertools import islice

# Look at one template of the 'cityServed' predicate. The templates live in a
# set, so position `i` is whatever the set's iteration order happens to yield.
i = 1
t = list(islice(template_db['cityServed'], i, i + 1))[0]

# original sentence, blank line, source triple, blank line, template string
print(t.text, "", t.triple, "", t, sep="\n")

The Atlantic City International Airport serves Atlantic City, N.J.

{'m_subject': 'Atlantic_City_International_Airport', 'm_object': 'Atlantic_City,_New_Jersey', 'm_predicate': 'cityServed'}

{m_subject} serves {m_object}, N.J.
