In [1]:
from flair.models import SequenceTagger

from flair.data import Sentence

import graphene
from elasticsearch import Elasticsearch
from elasticsearch.helpers import scan
from elasticsearch_dsl import Q, Search
import numpy as np
import pickle
import string
import sys
import timeit

from itertools import islice
from collections import deque, defaultdict, Counter

from nltk import ChartParser
from nltk.parse.generate import generate
from nltk.grammar import CFG, Nonterminal
from nltk.tree import Tree
from semantic.numbers import NumberService

HOST = "http://18.203.139.101:9200"
INDEX = "company"

client = Elasticsearch(HOST+"/")

## TODO:
* implement fallback search that just concats all fields and looks for query terms in there
* implement range queries
* huge stretch: implement logical functionality (and/or/not)

### Load Word Counts

### Use the above field values to make a simple query-parsing grammer

In [2]:
number_parser = NumberService()

ner_tagger = SequenceTagger.load('ner-ontonotes-fast')
pos_tagger = SequenceTagger.load("pos-fast")

comparator_expressions = {
        "gte": ["greater than or equal to", "minimum", "min", "at least", "not less than", "not fewer than", ">=",
                "no less than", "no fewer than"
               ],
        "gt": ["greater than", "more than", ">", "above", "over"],
        "lte": ["less than or equal to", "maximum", "max", "at most", "not greater than", "not more than", "<=",
                "no more than"
               ],
        "lt": ["less than", "fewer than", "<", "below", "under"]
    }

reverse_comparator_expressions = {v:k for k,vs in comparator_expressions.items() for v in vs}

with open("company_grammar.pkl", 'rb') as f:
    company_grammar = pickle.load(f)

2019-07-15 17:55:55,477 loading file /Users/andrei-alinpopescu/.flair/models/en-ner-ontonotes-fast-v0.3.pt
2019-07-15 17:56:04,510 loading file /Users/andrei-alinpopescu/.flair/models/en-pos-ontonotes-fast-v0.2.pt


In [3]:
def replace_tags(original_sentence, tagged_sentence, tagger_type, tags_to_replace = None):
    ''' Replace natural language sequence with its found tag
    :param original_sentence: str
    :param tagged_sentence: flair Sentence object that as been tagged by a Tagger
    :param tagger_type: str which tagger has been used 'ner' or 'pos'
    :param tags_to_replace: List[str] which tags found by Tagger to replace. If None, replace all tags
    
    :return : Tuple[str, Dict[str, deque[str]]] returns the sentence with tags replaced and the original values that
              the tags had in the text. This latter return is used when parsing the syntax tree to re-populate the
              initial textual values. Processing them in a FIFO manner ensures consistency when doing this (so initial
              values get mapped to the correct syntax tree leaves)
    '''

    tags = {}

    tagged_string = ""
    prev_end = 0
    
    print(tagged_sentence.to_dict(tagger_type)['entities'])
    
    for entity in tagged_sentence.to_dict(tagger_type)['entities']:
        start_pos = entity['start_pos']
        end_pos   = entity['end_pos']
        tag_type = f"<{entity['type'].lower()}>"
        if tags_to_replace is not None and tag_type not in tags_to_replace:
            continue
        if tag_type not in tags:
            tags[tag_type] = deque()
        tags[tag_type].append(original_sentence[start_pos:end_pos])

        tagged_string += original_sentence[prev_end:start_pos]
        tagged_string += tag_type
        prev_end = end_pos
    tagged_string += original_sentence[prev_end:]

    return tagged_string, tags


def preprocess_sentence(original_sentence):
    ''' Preprocess a free text query to something our grammar can handle. This works as follows:
        * do POS tagging on the initial sequence, replacing all found cardinals (this is because NER seems to have
          some strange assumptions around what a cardinal is)
        * from the above result, replace comparators with a canonical representation (e.g. 'fewer than' -> '<')
        * remove uninformative words such as interjections and conjunctions
        * do NER tagging to identify countries and regions
    
    :param : str the original sentence
    
    :return : Tuple[str, Dict[str, deque[str]]], return the tag-replaced sentence together with the actual initial tag values
    '''

    sentence = Sentence(original_sentence)
    pos_tagger.predict(sentence)
    pos_sentence, ca_tags = replace_tags(original_sentence, 
                                         sentence, 
                                         tagger_type="pos", 
                                         tags_to_replace=["<cd>"]
                                        )

    if '<cd>' in ca_tags:
        ca_tags['<cardinal>'] = ca_tags['<cd>']
        del ca_tags['<cd>']
        
    print(sentence.to_tagged_string())

    pos_sentence = pos_sentence.replace("<cd>", "<cardinal>")
    
    for matcher, comparator in sorted(reverse_comparator_expressions.items(), key=lambda a: -len(a[0])):
        pos_sentence = pos_sentence.replace(matcher +" <cardinal>", comparator+" <cardinal>")
    
    pos_sentence = pos_sentence.replace("between <cardinal> and <cardinal>", 
                                        "<range_start> <cardinal> <range_end> <cardinal>"
                                       )
        
    sentence = Sentence(pos_sentence)
    pos_tagger.predict(sentence)
    pos_sentence,_ = replace_tags(pos_sentence,
                                  sentence,
                                  tagger_type="pos",
                                  tags_to_replace=["<in>", "<cc>", "<vbg>", "<dt>"]
                                 )

    pos_sentence = pos_sentence.replace("<in>", "").replace("<cc>", "").replace("<vbg>", "").replace("<dt>", "")
    
    sentence = Sentence(pos_sentence)
    ner_tagger.predict(sentence)

    ner_sentence, ner_tags = replace_tags(pos_sentence, 
                                          sentence, 
                                          tagger_type="ner",
                                          tags_to_replace=["<gpe>", "<loc>"]
                                         )
    ner_tags.update(ca_tags)

    return ner_sentence, ner_tags

In [4]:
def _process_parse_tree(tree, tags):
    ''' Given a syntax tree and a mapping from tags to initial value, process the syntax tree in a DFS fashion to
        populate the with the original comparison values.
        
        Because all syntax trees have the same leaf order regardless of their actual structure, and because DFS
        traversal retrieves this order, we can safely populate the initial values in the order in which they are
        encountered in this traversal by popping from the appropriate tags queue. E.g. when we encounter a 'cardinal'
        leaf it is always safe to pop the first value in the tags['cardinal'] queue and label the respective 
        leaf with that value.
    '''
    
    ret = {}
    tag_terminals={"<cardinal>", "<gpe>", "<loc>", "<money>"}
    redundant_terminals={"employees", "revenue"}
    def dfs_label(root=tree):
        if isinstance(root, str):
            if root in tag_terminals:
                return tags[root].popleft()
            if root not in redundant_terminals:
                return root
            else:
                return None

        new_children = []
        for el in root:
            labelled_child = dfs_label(el)
            if labelled_child is not None:
                new_children.append(labelled_child)
        return Tree(root.label(), new_children)

    tree = dfs_label()
    return tree

def _populate_query(tree):
    ''' Given a labelled parse tree (i.e. with tags replaces with their original values), attempt to
        generate a mapping of the form field -> condition, which can be used to create an ES querystring
    '''
    ret = {}

    def process_range_subtree(root):
        return [
         ("gte",number_parser.parse(root[1])),
         ("lte",number_parser.parse(root[3]))
        ]
    
    def process_comparator_subtree(root):
        co, ca = None, None
        for child in root:
            if isinstance(child, Tree) and child.label() == 'Ineq':
                co = child[0]
            else:
                ca = child
        return [(co, number_parser.parse(ca))]
        
    def process_numerical_field_subtree(root):
        for child in root:
            if isinstance(child, Tree) and child.label() == 'CO':
                return process_comparator_subtree(child)
            if isinstance(child, Tree) and child.label() == 'Range':
                return process_range_subtree(child)
    
    def process_companytype_subtree(root):
        if isinstance(root, Tree):
            for child in root:
                for kw in process_companytype_subtree(child):
                    yield kw
        else:
            yield root

    def dfs_query_parser(root=tree):
        for child in root:
            if child.label() == 'Link':
                continue
            if child.label() == 'L':
                # TODO: parse L so we distinguish between various location-type queries
                continue
            if child.label() == 'CT':
                company_freetext = " ".join([el for el in process_companytype_subtree(child)])
                ret['industry_class_code_desc'] = company_freetext
                ret['company_description'] = company_freetext
            elif child.label() == 'REGION':
                if 'region' in ret:
                    raise ValueError("Can only have one region identifier")
                ret['region'] = list(child)[0]
            elif child.label() == 'COUNTRY':
                if 'country' in ret:
                    raise ValueError("Can only have one country identifier")
                ret['country'] = list(child)[0]
            elif child.label() == 'EMP':
                if 'number_of_employees' not in ret:
                    ret['number_of_employees'] = []
                ret['number_of_employees'].extend(process_numerical_field_subtree(child))
            elif child.label() == 'REV':
                if 'latest_revenue' not in ret:
                    ret['latest_revenue'] = []
                ret['latest_revenue'].extend(process_numerical_field_subtree(child))
            else:
                dfs_query_parser(child)

    dfs_query_parser()
    return ret

def _build_query_string(parse_dict):
    ''' Given a parse dictionary, produce an ES query string.
    '''
    ret = []
    free_text_company = []
    for field, parse_value in parse_dict.items():
        if field == 'latest_revenue' or field == 'number_of_employees':
            ret.append(f"{field}:({' AND '.join([''.join(co_ca) for co_ca in parse_value])})")
        elif field == 'company_description' or field == 'industry_class_code_desc':
            free_text_company.append("{}:({})".format(field, parse_value))
        else:
            ret.append("{}:({})".format(field, parse_value))
    
    ret.append("("+" OR ".join(free_text_company)+")")
    
    return " AND ".join(ret)

def _fallback_search(input_query, size=10):
    ''' Fall back on this in case nice query processing fails.
    '''
    print("Falling back on free text search")
    s = (Search()
         .using(client)
         .query(
             "query_string",
             fields=["country.keyword", "industry_class_code_desc", "company_description", "region"],
             type="cross_fields",
             query=input_query,
             fuzzy_prefix_length=3
         )
    )
    return s[:size]

def free_text_query(input_query, size=10):
    # N.B flair taggers are the execution bottleneck at the moment
    tagged_sentence, tags = preprocess_sentence(input_query)
    tagged_sentence = tagged_sentence.lower()
    
    parser = ChartParser(company_grammar)
    grammar_tree = None
    try:
        for tree in parser.parse(tagged_sentence.split()):
            grammar_tree = tree
    except Exception as e:
        print("Failed to parse grammar", e)
        return _fallback_search(input_query, size)
        
    if grammar_tree is None:
        print("No found grammar trees")
        return _fallback_search(input_query, size)

    processed_parse_tree = _process_parse_tree(grammar_tree, tags)
    parse_dict = _populate_query(processed_parse_tree)
    
    s = Search(index=INDEX).using(client)
    
    if "region" in parse_dict:
        s=s.query("match", region=parse_dict["region"])
        
    
    s=s.query("multi_match",
             fields = ["company_description", "industry_class_code_desc"],
             query=parse_dict["industry_class_code_desc"],
            )

    
    if "number_of_employees" in parse_dict:
        s=s.filter("range", number_of_employees=dict(parse_dict["number_of_employees"]))
    if "latest_revenue" in parse_dict:
        s=s.filter("range", latest_revenue=dict(parse_dict["latest_revenue"]))
    if "country" in parse_dict:
        s=s.filter({"terms": {"country.keyword": [parse_dict["country"]]}})

    print(s.to_dict())
    print()
    
    if s.count() == 0:
        print("Empty Results")
        return _fallback_search(input_query, size)
    return s[:size]

In [7]:
input_query = "drug stores in Europe with more than 55 employees and revenue less than 5"

In [8]:
for hit in free_text_query(input_query, size=100):
    print("score: ", hit.meta['score'])
    print("Name: ", hit.company_name)
    print("Company id: ", hit.company_id)
    print("Country: ", hit.country if hasattr(hit, "country") else "Empty")
    print("Revenue: ", hit.latest_revenue if hasattr(hit, "latest_revenue") else "Empty")
    print("Number of employees: ", hit.number_of_employees if hasattr(hit, "number_of_employees") else "Empty")
    print("Industry :", hit.industry_class_code_desc if hasattr(hit, "industry_class_code_desc") else "Empty")
    print("Company description: ", hit.company_description if hasattr(hit, "company_description") else "Empty")
    print("#"*100)

[{'text': 'drug', 'start_pos': 0, 'end_pos': 4, 'type': 'NN', 'confidence': 0.9994885921478271}, {'text': 'stores', 'start_pos': 5, 'end_pos': 11, 'type': 'NNS', 'confidence': 0.9993044137954712}, {'text': 'in', 'start_pos': 12, 'end_pos': 14, 'type': 'IN', 'confidence': 0.9999935626983643}, {'text': 'Europe', 'start_pos': 15, 'end_pos': 21, 'type': 'NNP', 'confidence': 0.9996509552001953}, {'text': 'with', 'start_pos': 22, 'end_pos': 26, 'type': 'IN', 'confidence': 0.9999909400939941}, {'text': 'more', 'start_pos': 27, 'end_pos': 31, 'type': 'JJR', 'confidence': 0.9984748959541321}, {'text': 'than', 'start_pos': 32, 'end_pos': 36, 'type': 'IN', 'confidence': 0.9999995231628418}, {'text': '55', 'start_pos': 37, 'end_pos': 39, 'type': 'CD', 'confidence': 0.9999997615814209}, {'text': 'employees', 'start_pos': 40, 'end_pos': 49, 'type': 'NNS', 'confidence': 0.9994364380836487}, {'text': 'and', 'start_pos': 50, 'end_pos': 53, 'type': 'CC', 'confidence': 1.0}, {'text': 'revenue', 'start_po

Industry : Variety stores
Company description:  Variety Stores
####################################################################################################
score:  6.663927
Name:  Danielsen Sko AS
Company id:  593036203
Country:  Norway
Revenue:  4.9161789254462
Number of employees:  57
Industry : Shoe stores
Company description:  Shoe Stores
####################################################################################################
score:  6.663927
Name:  Schuurman Groep B.V.
Company id:  46540676
Country:  Netherlands
Revenue:  0.3850185164403546
Number of employees:  289
Industry : Shoe stores
Company description:  Shoe Stores
####################################################################################################
score:  6.663927
Name:  Praxis Sociedad Comercializadora Sa
Company id:  260915055
Country:  Spain
Revenue:  3.873799798002469
Number of employees:  61
Industry : Book stores
Company description:  Book Stores
###################################