# Old stuff

In [6]:
import spacy
spacy.load('en_core_web_md')
from epitator.annotator import AnnoDoc
from epitator.geoname_annotator import GeonameAnnotator
from epitator.resolved_keyword_annotator import ResolvedKeywordAnnotator
from epitator.count_annotator import CountAnnotator
from epitator.date_annotator import DateAnnotator
from boilerpipe.extract import Extractor
from itertools import groupby
import datetime
import sys
import pandas as pd
import re
from tqdm import tqdm_notebook as tqdm
import numpy as np
import epitator

In [5]:
import pandas as pd
type(pd.DataFrame({"a":[1],"b":[2]})) == pd.DataFrame

True

## Read in Ereignisdatenbank

In [16]:
ereignisdatenbank = pd.read_csv("Ereignisse_utf8.csv",sep=";")

In [17]:
# Trim blanks around every column name so that the columns can be addressed by their plain name
ereignisdatenbank.columns = [column_name.strip(" ") for column_name in ereignisdatenbank.columns]

In [18]:
countries_not_null = ereignisdatenbank[pd.notnull(ereignisdatenbank["Ausgangs- bzw. Ausbruchsland"])]

In [19]:
# Copy the non-null country column and trim the blanks around every entry
countries = [country.strip(" ") for country
             in countries_not_null["Ausgangs- bzw. Ausbruchsland"].copy(deep=True)]

In [20]:
set(countries)

{'Afghanistan',
 'Afghanistan,\nDR Congo\nNigeria\nSomalia',
 'Algerien',
 'Angola',
 'Australien',
 'Bangladesch',
 'Benin',
 'Bolivien',
 'Brasilien',
 'Burundi',
 'China',
 'Costa Rica',
 'DRC',
 'DRCongo',
 'Demokratische Republik Kongo',
 'Deutschland',
 'El Salvador',
 'Fiji',
 'Frankreich',
 'Französich_Polynesien',
 'Französiche Guyana',
 'Französisch-Polynesien',
 'Ghana',
 'Haiti',
 'Indien',
 'Irak',
 'Iran',
 'Israel',
 'Italien',
 'Italien, Griechenland, Rumanien, Ungarn, Frankreich',
 'Italien, Griechenland, Ungarn, Rumänien',
 'Italien, Serbien, Griechenland, Rumänien, Ungarn, Frankreich, Kosovo, Albanien, Macedonien, Montenegro, Serbien, Türkei',
 'Kamerun',
 'Kanada',
 'Kenia',
 'Kolumbien',
 'Kongo',
 'Kroatien',
 'Kuwait',
 'La Reunion',
 'Liberia',
 'Madagaskar',
 'Malawi',
 'Mali',
 'Mauretanien',
 'Mosambik',
 'Myanmar',
 'Namibia',
 'Niger',
 'Nigeria',
 'Nordeuropa',
 'Oman',
 'Pakistan',
 'Papua-Neuguinea',
 'Peru',
 'Philippinen',
 'Polen',
 'Saudi-Arabien',
 

# Let the testing begin
## Parser for Wikipedia table of countries

In [1]:
from bs4 import BeautifulSoup
import requests

In [2]:
# Download the German Wikipedia list of states. The timeout keeps the cell from hanging forever
# and raise_for_status() stops the notebook instead of parsing an HTTP error page.
req = requests.get("https://de.wikipedia.org/wiki/Liste_der_Staaten_der_Erde", timeout=30)
req.raise_for_status()
soup = BeautifulSoup(req.content,"html.parser")

In [3]:
# All rows of the table that lists the countries
country_rows = soup.find("table",class_="wikitable sortable zebra").find("tbody").find_all("tr")
amount_countries = len(country_rows)
# One list of <td> cells per table row
parsed_soup = [row.find_all('td') for row in country_rows]

In [14]:
# Columns of the country table that is built from the parsed Wikipedia rows
wiki_dict = {"state_name_de":[],
             "full_state_name_de":[],
             "capital_de":[],"translation_state_name":[],
             "wiki_abbreviations":[]}
dash = u"\u2014" # Used dash for missing entry in Wikipedia table
regex = re.compile(r"\[\d*\]") # To remove footnotes in the names
for i in range(amount_countries):
    cells = parsed_soup[i]
    try:
        # Read every needed cell before anything is appended. A row that is too short then raises
        # IndexError without leaving the columns of wiki_dict with different lengths.
        state_name_de = regex.sub("",(cells[0].text).replace("\n","")\
                                  .replace("\xad","")) # Remove soft hyphen used in Zentralafr. Rep.

        # Remove additional information that is not part of the state name
        state_name_de = re.sub(r"((mit)|(ohne)).*","",state_name_de)

        # Footnotes and new lines are removed
        full_state_name = regex.sub("",cells[1].text).replace("\n","")
        capital = regex.sub("",cells[2].text).replace("\n","")
        translated_name = regex.sub("",cells[10].text).replace("\n","")

        # Column 7 and 8 are long and short official abbreviations for the countries
        list_abbreviation = [cells[7].text.replace("\n",""),cells[8].text.replace("\n","")]
    except IndexError as e: # Header and footer are part of the table and cannot be parsed like a country
        if i not in (0, amount_countries - 1):
            print("Entry {} failed unexpected because of {}".format(i,e)) # Only first and last row may fail
        continue

    # Remove empty abbrev. entries. E.g. Abchasien | ["ABC", "—"] --> Abchasien | ["ABC"]
    list_abbreviation = [abb for abb in list_abbreviation if abb not in ("", dash)]

    wiki_dict["state_name_de"].append(state_name_de)
    wiki_dict["full_state_name_de"].append(full_state_name)
    wiki_dict["capital_de"].append(capital)
    wiki_dict["translation_state_name"].append(translated_name)
    # Keep every remaining abbreviation, also a single one. Only when no abbreviation remains after
    # the removal of empty entries, enter a single dash.
    wiki_dict["wiki_abbreviations"].append(list_abbreviation if list_abbreviation else dash)

In [10]:
wikipedia_country_list = pd.DataFrame.from_dict(wiki_dict)
wikipedia_country_list.head()

Unnamed: 0,state_name_de,full_state_name_de,capital_de,translation_state_name,wiki_abbreviations
0,Erde,—,—,Earth,—
1,Europäische Union,—,Brüssel,European Union,—
2,Union Südamerikanischer Nationen,—,Quito,Union of South American Nations,—
3,Afrikanische Union,—,Addis Abeba,African Union,—
4,Verband Südostasiatischer Nationen,—,Jakarta,Association of Southeast Asian Nations,—


In [8]:
def abbreviate(country_name):
    """Build an abbreviation from the capital letters of a country name.

    Text in parentheses is ignored. If the name contains a comma, the part after the first comma
    is moved to the front so that the more common abbreviation results
    ("Korea, Nord" is read as "Nord Korea").

    Args:
        country_name: Name of a country or organisation as a string.

    Returns:
        The capital letters (A-Z, Ä, Ö, Ü) joined to one string, or None if the name contains
        fewer than two capital letters.

    Example: United Kingdom --> UK
    """
    country_name = re.sub(r"\(.*\)","", country_name) # Content in parentheses is not relevant for abbrev.
    if "," in country_name:
        # partition() also handles names with several words or non-ASCII letters before the comma,
        # for which a strict "Word, rest" pattern would not match
        before_comma, _, after_comma = country_name.partition(",")
        country_name = after_comma.strip() + " " + before_comma.strip()
    # No "|" inside the character class: there it would be matched as a literal pipe
    capitals = re.findall(r"[A-ZÄÖÜ]", country_name)
    if len(capitals) > 1:
        return "".join(capitals)
    return None

In [528]:
# Search for names that might have abbreviations. If they consist of two or more words that start with a capital
# letter, it makes an abbreviation out of it
abb_state_de = [abbreviate(name) for name in wikipedia_country_list["state_name_de"].tolist()]
abb_full_state_de = [abbreviate(name) for name in wikipedia_country_list["full_state_name_de"].tolist()]
abb_state_trans = [abbreviate(name) for name in wikipedia_country_list["translation_state_name"].tolist()]

# Per country: the distinct abbreviations that could be built (Nones dropped), or "-" if there is none
abbreviations = [
    list(set(built)) if built else "-"
    for built in ([abb for abb in candidates if abb]
                  for candidates in zip(abb_state_de,abb_full_state_de,abb_state_trans))
]
#abbreviations = list(map(", ".join,cleaned_abbreviations)) # Unpack list of abbreviations to string

In [530]:
wikipedia_country_list["inoff_abbreviations"] =  abbreviations
wikipedia_country_list.head()

Unnamed: 0,state_name_de,full_state_name_de,capital_de,translation_state_name,wiki_abbreviations,inoff_abbreviations
0,Erde,—,—,Earth,—,-
1,Europäische Union,—,Brüssel,European Union,—,[EU]
2,Union Südamerikanischer Nationen,—,Quito,Union of South American Nations,—,"[USN, USAN]"
3,Afrikanische Union,—,Addis Abeba,African Union,—,[AU]
4,Verband Südostasiatischer Nationen,—,Jakarta,Association of Southeast Asian Nations,—,"[VSN, ASAN]"


## Ontology/Comparison (Transformed to .py until here)

In [531]:
from epitator.annotator import AnnoDoc
from epitator.geoname_annotator import location_contains
# Annotate a sentence that contains the misspelled city name "Munic" and show the first
# geoname annotation as a dict.
# NOTE(review): location_contains is imported here but not used in this cell.
# NOTE(review): GeonameAnnotator is not imported in this cell; it is imported in the first cell.
doc = AnnoDoc("I live in Munic!")
doc.add_tiers(GeonameAnnotator())
annotations = doc.tiers["geonames"]
geoname = annotations[0]
geoname.to_dict()

{'label': 'Munich',
 'textOffsets': [[10, 15]],
 'geoname': {'geonameid': '2867714',
  'name': 'Munich',
  'feature_code': 'PPLA',
  'country_code': 'DE',
  'admin1_code': '02',
  'admin2_code': '091',
  'admin3_code': '09162',
  'admin4_code': '09162000',
  'longitude': 11.57549,
  'latitude': 48.13743,
  'population': 1260391,
  'asciiname': 'Munich',
  'names_used': 'Munic',
  'name_count': 88,
  'country_name': 'Federal Republic of Germany',
  'admin1_name': 'Bavaria',
  'admin2_name': 'Upper Bavaria',
  'admin3_name': 'Kreisfreie Stadt München',
  'parents': [],
  'score': 0.21453999698331166}}

In [532]:
#To .py and renamed to clean_country_names
def clean_entries(countries):
    """Clean the country entries of the Ereignisdatenbank.

    Every distinct entry is cleaned as follows: new lines are treated as commas, remarks in
    parentheses are removed, "&" becomes "und", "_" becomes a blank and a leading cardinal
    direction is merged with the following word ("Süd Sudan" --> "Südsudan").

    Args:
        countries: Iterable of strings.

    Returns:
        A list with one item per distinct input entry. An entry that names one country becomes
        a string, an entry that lists several countries becomes a list of strings. Entries that
        are empty after cleaning are dropped. The order of the result is not defined because the
        input is deduplicated with a set.
    """
    card_dir = re.compile(r"(Süd|Nord|West|Ost)\s(\S*)") # Matches cardinal directions and the string after it
    in_parentheses = re.compile(r"\([^()]*\)") # Innermost pair of parentheses with its content
    cleaned = []
    for raw_entry in set(countries): # Set is used for better overview and faster calculation
        # Because someone used new lines in entries instead of comma to list countries.
        # The second replace removes the comma that the first one adds too much.
        entry = raw_entry.replace("\n", ", ").replace(",,", ",")

        # Remove the pairs of parentheses one after another from the inside out. A single greedy
        # match from the first "(" to the last ")" would also delete the countries in between.
        previous = None
        while previous != entry:
            previous = entry
            entry = in_parentheses.sub("", entry)

        entry = entry.strip(" ").replace("&", "und")
        if not entry:
            continue # Nothing is left, e.g. the blank between two commas

        if "," in entry:
            # Entry with more than one country: clean every part on its own
            parts = clean_entries(entry.split(","))
            if parts:
                cleaned.append(parts)
            continue

        entry = entry.replace("_", " ")
        # To transform Süd Sudan to Südsudan
        # NOTE(review): text after the word that follows the cardinal direction is dropped here,
        # as in the original implementation - confirm that this is intended
        direction = card_dir.match(entry)
        if direction:
            entry = direction[1] + direction[2].lower()
        cleaned.append(entry)
    return cleaned

In [535]:
# Test for clean_entries()
from deep_eq import deep_eq
# NOTE(review): the result of this call is not used; it only runs clean_entries on the real data
clean_entries(countries)
# Raw entries with a leading blank, a new line as separator, a comma separated list,
# an underscore and an ampersand
example_countries_to_clean = [" Australien",
                              "Kongo \nUSA",
                              "Italien, Deutschland, Belgien ",
                              "Franz._Polynesien", 
                              "Trinidad & Tobago"]
# NOTE(review): the expected items are listed in a different order than the input; clean_entries
# deduplicates with a set, so this presumably relies on how deep_eq compares lists - confirm
expected_countries_to_clean = ["Trinidad und Tobago","Australien"
                               ,['Belgien', 'Deutschland', 'Italien']
                               ,["USA", "Kongo"], "Franz. Polynesien"]
if deep_eq(clean_entries(example_countries_to_clean),expected_countries_to_clean):
    print("Test succesful")
else:
    print("Test failed")

Test succesful


In [None]:
# # FOR TESTING. RETURNS TUPLE WITH ABBREVIATION AND TRANSLATION
# # Takes a list of not matched/translated entries and tries to match them to the wikipedia table and find the full name
# countries_not_translated = [entry for entry in countries_unique \
#                             if entry not in wikipedia_country_list["state_name_de"].tolist()]
# def translate_abbreviation(to_translate):
#     abb_to_country = []
#     if type(to_translate) == str:
#         to_translate = [to_translate]
#     for column in ["wiki_abbreviations","inoff_abbreviations"]:
#         for potential_abbreviation in to_translate:
#             if type(potential_abbreviation) == str:
#                 for i, abbreviation in enumerate(wikipedia_country_list[column]):
#                     if potential_abbreviation in abbreviation:
#                         abb_to_country.append((potential_abbreviation,\
#                                                wikipedia_country_list["translation_state_name"].tolist()[i]))
#                         to_translate.remove(potential_abbreviation)
#             elif type(potential_abbreviation) == list:
#                 abb_to_country.append(translate_abbreviation(potential_abbreviation))
#     return(abb_to_country,to_translate)

# abbreviation_tuple, countries_not_translated = translate_abbreviation(countries_not_translated)

# #abbreviation_tuple
# #print("****************************************************************************************************************")
# #print(countries_not_translated)

In [536]:
def translate_abbreviation(to_translate):
    """Replace abbreviations in a list of country entries by the German state name.

    An entry counts as abbreviation if it consists of the letters A-Z only. It is first looked up
    in the official abbreviations ("wiki_abbreviations") and only if nothing is found there in the
    self created ones ("inoff_abbreviations"), e.g. VAE for the Emirates. The search stops at the
    first column that yields a match, so a country is not returned twice.

    Args:
        to_translate: A string or a list whose items are strings or lists of strings.

    Returns:
        A list in the order of the input. Abbreviations are replaced by the matching entries of
        wikipedia_country_list["state_name_de"] (several if the abbreviation is ambiguous within
        a column). Entries that are no abbreviation or that are not found are kept unchanged.
        Nested lists are translated item by item and stay nested.
    """
    if type(to_translate) == str:
        to_translate = [to_translate]
    state_names = wikipedia_country_list["state_name_de"].tolist() # Converted once, not per match
    to_return = []
    for potential_abbreviation in to_translate:
        if type(potential_abbreviation) == list:
            list_entry = [translate_abbreviation(nested_entry) for nested_entry in potential_abbreviation]
            to_return.append([entry for sublist in list_entry for entry in sublist])
        elif type(potential_abbreviation) == str and re.fullmatch(r"[A-Z]+", potential_abbreviation):
            # fullmatch needs at least one letter: an empty string would otherwise be "contained"
            # in every dash that marks a missing abbreviation
            matches = []
            for column in ["wiki_abbreviations","inoff_abbreviations"]:
                matches = [state_names[i]
                           for i, abbreviation in enumerate(wikipedia_country_list[column])
                           if potential_abbreviation in abbreviation]
                if matches:
                    break
            # An unknown abbreviation is kept instead of silently dropped from the result
            to_return.extend(matches if matches else [potential_abbreviation])
        else:
            to_return.append(potential_abbreviation)
    return to_return

In [537]:
# Test for translate_abbreviation()
# Abbreviations, plain names, a name that only looks similar to an abbreviation and nested lists
example_to_abbreviate = ["USA","VAE",'Italien', "DR Cong",["Deutschland", "EU"],["Belgien","DRC"]]
desired_output = ['Vereinigte Staaten','Vereinigte Arabische Emirate','Italien','DR Cong',
                  ['Deutschland', 'Europäische Union'],
                  ['Belgien', 'Kongo, Demokratische Republik']]
if deep_eq(translate_abbreviation(example_to_abbreviate),desired_output):
    print("Test succesful")
else:
    # print must be lower case, otherwise a failing test raises NameError instead of reporting
    print("Test failed")

Test succesful


In [None]:
## SIMPLE TRANSLATION. FAST BUT DOES NOT TRANSLATE LISTS OF LISTS
# # Translate German entries of Ereignisdatenbank to English. Might be inefficient since I go through the wiki list
# # entirely which is longer then the list of countries to translate
# translated_ereignisdatenbank_countries = [(entry,wikipedia_country_list["translation_state_name"].tolist()[indx])\
#                                           for indx,entry \
#                                           in enumerate(wikipedia_country_list["state_name_de"].tolist())\
#                                           if entry in countries_unique]

In [538]:
from didyoumean import didyoumean

def _counterparts_of_matches(entry, names, counterparts):
    """Return the counterpart of every name that contains entry as plain substring."""
    return [counterpart for name, counterpart in zip(names, counterparts) if entry in name]


def _translate_name(entry, state_name_de, full_state_name_de, translation, continents):
    """Translate one non-empty country name, see translate() for the rules and the result."""
    found = _counterparts_of_matches(entry, state_name_de, translation)
    if found:
        if len(found) == 1:
            return (entry, found[0])
        # Check for identity in the ambiguous case, otherwise e.g. Niger --> (Niger, Nigeria)
        identical = [found_ent for found_ent in found if entry == found_ent]
        return (entry, identical[0]) if identical else (entry, found)

    # If entry not in state_name_de, search in full_state_name_de
    found = _counterparts_of_matches(entry, full_state_name_de, translation)
    if found:
        return (entry, found[0]) if len(found) == 1 else (entry, found)

    # The entry may already be English: search the translations and return the German name(s)
    found = _counterparts_of_matches(entry, translation, state_name_de)
    if found:
        return (found[0], entry) if len(found) == 1 else (found, entry)

    # If there was no match at all, check for spelling mistakes
    did_u_mean = didyoumean.didYouMean(entry, state_name_de)

    # Exclude entries that contain a continent name. There are countries with a continent in
    # their name, but there are also entries in the Ereignisdatenbank that mean a whole region
    # (e.g. Nordafrika). These must stay unmatched, otherwise didYouMean would falsely return
    # Südafrika. The entry itself has to be checked: the suggestion is always a state name.
    mentions_continent = any(continent in entry.lower() for continent in continents)
    if did_u_mean and not mentions_continent:
        # translate() returns a list with one item for one name; unpack it so that a corrected
        # entry has the same shape as every other translated entry
        return translate(did_u_mean)[0]
    return entry


def translate(countries_unique):
    """Translate German country entries of the Ereignisdatenbank to English.

    An entry is searched as plain substring (no regular expression) first in the German state
    names, then in the full German state names and last in the English translations of
    wikipedia_country_list. If nothing is found, didYouMean is asked for a spelling correction,
    except for entries that contain a continent name.

    Args:
        countries_unique: A string or a list whose items are strings or lists of strings.

    Returns:
        A list with one item per entry: a tuple (German name, English translation), where one
        side is a list if the entry is ambiguous; the unchanged entry if it could not be
        translated; a list of such items for a nested list.
    """
    # German and English spelling, compared against the lower-cased entry
    continents = ["europa","afrika","africa","amerika","america","australien","asien","asia"]
    translated_ereignisdatenbank_countries = []
    state_name_de = wikipedia_country_list["state_name_de"].tolist()
    full_state_name_de = wikipedia_country_list["full_state_name_de"].tolist()
    translation = wikipedia_country_list["translation_state_name"].tolist()

    if type(countries_unique) == str:
        countries_unique = [countries_unique]

    for entry in countries_unique:
        if type(entry) == list:
            translated_ereignisdatenbank_countries.append(translate(entry))
        elif type(entry) != str or not entry.strip():
            # Not a name. A blank string would be contained in every country name.
            translated_ereignisdatenbank_countries.append(entry)
        else:
            translated_ereignisdatenbank_countries.append(
                _translate_name(entry, state_name_de, full_state_name_de, translation, continents))
    return translated_ereignisdatenbank_countries

In [540]:
# Test for translate(): a plain country, a name that is expected to stay untranslated, a nested
# list with an entry that matches two states, and a name that is contained in another state name
example_to_translate = ["Deutschland","Delaware",["Kongo","China"],"Niger"]
expected_result_translate = [('Deutschland', 'Germany'),
                             'Delaware',
                             [('Kongo',
                               ['Congo, Democratic Republic of the (Kinshasa)','Congo, Republic of (Brazzaville)']),
                              ('China','China')],
                            ('Niger', 'Niger')]
if deep_eq(translate(example_to_translate),expected_result_translate):
    print("Test succesfull")
else:
    print("Test failed")

Test succesfull


In [541]:
translated = translate(translate_abbreviation(clean_entries(countries)))
translated

[('Algerien', 'Algeria'),
 ('Peru', 'Peru'),
 ('Äthiopien', 'Ethiopia'),
 [('Ungarn', 'Hungary'),
  ('Serbien', 'Serbia'),
  ('Rumänien', 'Romania'),
  ('Italien', 'Italy'),
  ('Griechenland', 'Greece')],
 ('Sri Lanka', 'Sri Lanka'),
 ('El Salvador', 'El Salvador'),
 ('Jemen', 'Yemen'),
 ('Namibia', 'Namibia'),
 ('Simbabwe', 'Zimbabwe'),
 ('Kongo, Demokratische Republik',
  'Congo, Democratic Republic of the (Kinshasa)'),
 ('Angola', 'Angola'),
 ('Burundi', 'Burundi'),
 ('Vereinigte Staaten', 'United States'),
 ('Kolumbien', 'Colombia'),
 ('Bolivien', 'Bolivia'),
 ('Schweiz', 'Switzerland'),
 ('Südsudan', 'South Sudan'),
 ('Haiti', 'Haiti'),
 ('Venezuela', 'Venezuela'),
 ('Niger', 'Niger'),
 ('Ghana', 'Ghana'),
 ('Costa Rica', 'Costa Rica'),
 ('China', 'China'),
 ('Bangladesch', 'Bangladesh'),
 ('Somalia', 'Somalia'),
 ('Syrien', 'Syria'),
 'Französich Polynesien',
 ('Kanada', 'Canada'),
 ('Tansania', 'Tanzania'),
 ('Nigeria', 'Nigeria'),
 ('Taiwan', 'Taiwan oder Republic of China'),
 

In [542]:
# Flatten one level: the items of nested lists are pulled up, tuples and strings are kept as they
# are. Strings and tuples must not be iterated over, otherwise every untranslated name is emitted
# once per character and every tuple once per member.
flattened = [entry
             for item in translated
             for entry in (item if type(item) == list else [item])]
# Translated entries are tuples, so the remaining strings are the entries without translation
countries_not_translated = {entry for entry in flattened if type(entry) == str}
countries_not_translated

{'DR Congo',
 'DRCongo',
 'Delaware',
 'Französich Polynesien',
 'Französiche Guyana',
 'Französisch-Polynesien',
 'La Reunion',
 'Nordeuropa',
 'Typhus',
 'VAE Dubai'}

# Understanding geoname annotator

In [None]:
"""Geoname Annotator"""

from geopy.distance import great_circle
from .maximum_weight_interval_set import Interval, find_maximum_weight_interval_set

# Containment levels indicate which properties must match when determing
# whether a geoname of a given containment level contains another geoname.
# The admin codes generally correspond to states, provinces and cities.
CONTAINMENT_LEVELS = [
    'country_code',
    'admin1_code',
    'admin2_code',
    'admin3_code',
    'admin4_code'
]

# Row keys that GeonameRow.__init__ copies onto the instance and that GeonameRow.to_dict exports.
# NOTE(review): the same list is defined a second time further down in this cell.
GEONAME_ATTRS = [
    'geonameid',
    'name',
    'feature_code',
    'country_code',
    'admin1_code',
    'admin2_code',
    'admin3_code',
    'admin4_code',
    'longitude',
    'latitude',
    'population',
    'asciiname',
    'names_used',
    'name_count']
def location_contains(loc_outer, loc_inner):
    """
    Tell whether the geoname loc_outer contains the geoname loc_inner.

    The result is the containment level as an integer. It depends only on how
    specific the outer location is: 1 for a PCL* feature code, 2 to 5 for
    ADM1 to ADM4, so a country yields a smaller number than one of its states.
    0 means no containment. This is returned for locations of different
    countries, for an outer location without country code or with another
    feature code, for siblings and for identical locations.
    """
    # Cheap test in advance: without a common country code there is no containment.
    outer_country = loc_outer.country_code
    if outer_country != loc_inner.country_code or outer_country == '':
        return 0
    feature_code = loc_outer.feature_code
    admin_feature_levels = {'ADM1': 2, 'ADM2': 3, 'ADM3': 4, 'ADM4': 5}
    if feature_code in admin_feature_levels:
        outer_feature_level = admin_feature_levels[feature_code]
    elif re.match("^PCL.", feature_code):
        outer_feature_level = 1
    else:
        return 0
    # Every admin code up to the level of the outer location must be set and equal.
    for prop in CONTAINMENT_LEVELS[1:outer_feature_level]:
        outer_code = loc_outer[prop]
        if outer_code == '' or outer_code != loc_inner[prop]:
            return 0
    # A location does not contain itself.
    return 0 if loc_outer.geonameid == loc_inner.geonameid else outer_feature_level



# Row keys that GeonameRow.__init__ copies onto the instance and that to_dict exports
GEONAME_ATTRS = [
    'geonameid',
    'name',
    'feature_code',
    'country_code',
    'admin1_code',
    'admin2_code',
    'admin3_code',
    'admin4_code',
    'longitude',
    'latitude',
    'population',
    'asciiname',
    'names_used',
    'name_count']


# Optional name attributes; to_dict exports them only if they have been set
ADMINNAME_ATTRS = [
    'country_name',
    'admin1_name',
    'admin2_name',
    'admin3_name']


class GeonameRow(object):
    """
    A geoname built from a database row.

    The row values listed in GEONAME_ATTRS become attributes. They can also be
    read with subscript syntax (row['name']), which to_dict and
    location_contains rely on.
    """
    __slots__ = GEONAME_ATTRS + ADMINNAME_ATTRS + [
        'alternate_locations',
        'spans',
        'parents',
        'score',
        'lat_long',
        'high_confidence']

    def __init__(self, sqlite3_row):
        """
        Copy the GEONAME_ATTRS values of sqlite3_row, other keys are ignored.

        sqlite3_row must provide keys() and subscript access and has to
        contain at least 'latitude' and 'longitude'.
        """
        for key in sqlite3_row.keys():
            if key in GEONAME_ATTRS:
                setattr(self, key, sqlite3_row[key])
        self.lat_long = (self.latitude, self.longitude,)
        self.alternate_locations = set()
        self.spans = set()
        self.parents = set()
        self.score = None

    def __getitem__(self, key):
        """Return the attribute named key; raises AttributeError if it is not set."""
        return getattr(self, key)

    def to_dict(self):
        """
        Return the geoname as a dict.

        It contains all GEONAME_ATTRS, the ADMINNAME_ATTRS that are set, the
        parents as a list of dicts and the score.
        """
        result = {}
        for key in GEONAME_ATTRS:
            result[key] = self[key]
        for key in ADMINNAME_ATTRS:
            if hasattr(self, key):
                result[key] = self[key]
        result['parents'] = [p.to_dict() for p in self.parents]
        result['score'] = self.score
        return result


class GeonameFeatures(object):
    """
    This represents the aspects of a candidate geoname that are used to
    determine whether it is being referenced.

    The values are kept in the list self._values, in the order given by
    feature_names.
    """
    # The feature name array is used to maintain the order of the
    # values in the feature vector.
    feature_names = [
        'log_population',
        'name_count',
        'num_spans',
        'max_span_length',
        'cannonical_name_used',
        'loc_NE_portion',
        'other_NE_portion',
        'noun_portion',
        'other_pos_portion',
        'num_tokens',
        'ambiguity',
        'PPL_feature_code',
        'ADM_feature_code',
        'CONT_feature_code',
        'other_feature_code',
        'combined_span_parents',
        'close_locations',
        'very_close_locations',
        'containing_locations',
        'max_containment_level',
        # high_confidence indicates the base feature set received a high score.
        # It is an useful feature for preventing high confidence geonames
        # from receiving low final scores when they lack contextual cues -
        # for example, when they are the only location mentioned.
        'high_confidence',
    ]

    def __init__(self, geoname, spans_to_nes, span_to_tokens):
        """
        Compute the features that need only the geoname and its spans.

        geoname -- object whose population, name_count, spans, name,
            asciiname, parents, alternate_locations and feature_code are read
        spans_to_nes -- indexed by span, yields items with a `label` attribute
        span_to_tokens -- indexed by span, yields items with a `token`
            attribute whose `tag_` is read

        NOTE(review): an empty geoname.spans makes max() raise ValueError, and
        zero spans or zero tokens make the portion divisions raise
        ZeroDivisionError - presumably callers only pass geonames with spans;
        confirm.
        """
        self.geoname = geoname
        # The set of geonames that are mentioned in proximity to the spans
        # corresponding to this feature.
        # This will be populated by the add_contextual_features function.
        self.nearby_mentions = set()
        d = {}
        # The + 1 keeps the logarithm defined for a population of 0.
        d['log_population'] = math.log(geoname.population + 1)
        # Geonames with lots of alternate names
        # tend to be the ones most commonly referred to.
        d['name_count'] = geoname.name_count
        d['num_spans'] = len(geoname.spans)
        d['max_span_length'] = max([
            len(span.text) for span in geoname.spans])

        def cannonical_name_match(span, geoname):
            """
            Return the length of the span text divided by the length of
            geoname.name if the text occurs in the name or the asciiname,
            otherwise 0. The text of the first leaf base span is used when
            the span has one.
            """
            first_leaf = next(span.iterate_leaf_base_spans(), None)
            if first_leaf:
                span_text = first_leaf.text
            else:
                span_text = span.text
            span_in_name = span_text in geoname.name or span_text in geoname.asciiname
            return (float(len(span_text)) if span_in_name else 0) / len(geoname.name)
        d['cannonical_name_used'] = max([
            cannonical_name_match(span, geoname)
            for span in geoname.spans
        ])
        # Count the named entities on the spans, split into location labels
        # (GPE, LOC) and all other labels.
        loc_NEs_overlap = 0
        other_NEs_overlap = 0
        total_spans = len(geoname.spans)
        for span in geoname.spans:
            for ne_span in spans_to_nes[span]:
                if ne_span.label == 'GPE' or ne_span.label == 'LOC':
                    loc_NEs_overlap += 1
                else:
                    other_NEs_overlap += 1
        d['loc_NE_portion'] = float(loc_NEs_overlap) / total_spans
        d['other_NE_portion'] = float(other_NEs_overlap) / total_spans
        # Count the tokens of the spans, split into tags that start with "NN"
        # or equal "FW" and all other tags.
        noun_pos_tags = 0
        other_pos_tags = 0
        pos_tags = 0
        for span in geoname.spans:
            for token_span in span_to_tokens[span]:
                token = token_span.token
                pos_tags += 1
                if token.tag_.startswith("NN") or token.tag_ == "FW":
                    noun_pos_tags += 1
                else:
                    other_pos_tags += 1
        d['combined_span_parents'] = len(geoname.parents)
        d['noun_portion'] = float(noun_pos_tags) / pos_tags
        d['other_pos_portion'] = float(other_pos_tags) / pos_tags
        d['num_tokens'] = pos_tags
        d['ambiguity'] = len(geoname.alternate_locations)
        # Exactly one of the four feature code flags is set to 1; the others
        # keep the 0 of the initial value list below.
        feature_code = geoname.feature_code
        if feature_code.startswith('PPL'):
            d['PPL_feature_code'] = 1
        elif feature_code.startswith('ADM'):
            d['ADM_feature_code'] = 1
        elif feature_code.startswith('CONT'):
            d['CONT_feature_code'] = 1
        else:
            d['other_feature_code'] = 1
        # Features that are not in d (e.g. the contextual ones) start as 0.
        self._values = [0] * len(self.feature_names)
        self.set_values(d)

    def set_value(self, feature_name, value):
        """Set one feature; raises ValueError if feature_name is not in feature_names."""
        self._values[self.feature_names.index(feature_name)] = value

    def set_values(self, value_dict):
        """Set every feature that has a key in value_dict; other keys are ignored."""
        for idx, name in enumerate(self.feature_names):
            if name in value_dict:
                self._values[idx] = value_dict[name]

    def set_contextual_features(self):
        """
        GeonameFeatures are initialized with only values that can be extracted
        from the geoname database and span. This extends the GeonameFeature
        with values that require information from nearby_mentions.
        """
        geoname = self.geoname
        close_locations = 0
        very_close_locations = 0
        containing_locations = 0
        max_containment_level = 0
        for recently_mentioned_geoname in self.nearby_mentions:
            if recently_mentioned_geoname == geoname:
                continue
            # Containment is checked in both directions, the larger level counts.
            containment_level = max(
                location_contains(geoname, recently_mentioned_geoname),
                location_contains(recently_mentioned_geoname, geoname))
            if containment_level > 0:
                containing_locations += 1
            if containment_level > max_containment_level:
                max_containment_level = containment_level
            distance = great_circle(
                recently_mentioned_geoname.lat_long, geoname.lat_long
            ).kilometers
            # A mention closer than 100 km is counted as close and as very close.
            if distance < 400:
                close_locations += 1
            if distance < 100:
                very_close_locations += 1
        self.set_values(dict(
            close_locations=close_locations,
            very_close_locations=very_close_locations,
            containing_locations=containing_locations,
            max_containment_level=max_containment_level))


class GeonameAnnotator(Annotator):
    """
    Annotator that links location mentions in a document to geonames.

    Candidate geonames are looked up in the geonames sqlite database by the
    document's n-grams, scored with a classifier and returned as a 'geonames'
    tier of GeoSpans.

    NOTE(review): annotate() calls self.extract_features(), which is not
    defined in this class as it appears here. Unless the Annotator base class
    provides it, annotate() raises AttributeError; it looks like the method
    was lost when this code was pasted -- confirm and restore it.
    """

    def __init__(self, custom_classifier=None):
        """
        Open the geonames database connection and select the classifier.

        custom_classifier -- optional replacement for the module-level
            geoname_classifier. annotate() reads predict_proba_base,
            predict_proba_contextual, HIGH_CONFIDENCE_THRESHOLD and
            GEONAME_SCORE_THRESHOLD from it.
        """
        self.connection = get_database_connection()
        self.connection.row_factory = sqlite3.Row
        # Test against None instead of truthiness: a classifier object that
        # happens to evaluate as falsy (e.g. one defining __len__) must not be
        # silently swapped for the default.
        if custom_classifier is not None:
            self.geoname_classifier = custom_classifier
        else:
            self.geoname_classifier = geoname_classifier

    def get_candidate_geonames(self, doc):
        """
        Returns a list of GeonameRow objects corresponding to locations that
        the document may refer to.
        Each returned geoname has at least one associated span; its spans,
        parents and alternate_locations are filled in here.

        The 'ngrams' and 'nes' tiers are added to doc when missing.
        """
        if 'ngrams' not in doc.tiers:
            doc.add_tiers(NgramAnnotator())
        logger.info('Ngrams annotated')
        if 'nes' not in doc.tiers:
            doc.add_tiers(NEAnnotator())
        logger.info('Named entities annotated')

        # Associate spans with their lowercased text.
        # This is done up front so the n-grams are filtered only once: the
        # keys are the distinct query terms and the values are used below to
        # attach spans to the fetched geonames.
        span_text_to_spans = defaultdict(list)
        for span in doc.tiers['ngrams'].spans:
            if is_possible_geoname(span.text):
                span_text_to_spans[span.text.lower()].append(span)
        all_ngrams = list(span_text_to_spans)
        cursor = self.connection.cursor()
        # One placeholder is bound per distinct n-gram.
        # NOTE(review): SQLite limits the number of bound parameters per
        # statement, so a very long document could exceed it -- confirm
        # against the SQLite build in use.
        geoname_results = list(cursor.execute('''
        SELECT
            geonames.*,
            count AS name_count,
            group_concat(alternatename, ";") AS names_used
        FROM geonames
        JOIN alternatename_counts USING ( geonameid )
        JOIN alternatenames USING ( geonameid )
        WHERE alternatename_lemmatized IN
        (''' + ','.join('?' for _ in all_ngrams) + ''')
        GROUP BY geonameid''', all_ngrams))
        logger.info('%s geonames fetched', len(geoname_results))
        geoname_results = [GeonameRow(g) for g in geoname_results]
        candidate_geonames = []
        for geoname in geoname_results:
            geoname.add_spans(span_text_to_spans)
            # In rare cases geonames may have no matching spans because
            # sqlite unicode equivalency rules match geonames that use different
            # characters the document spans used to query them.
            # These geonames are ignored.
            if len(geoname.spans) > 0:
                candidate_geonames.append(geoname)
        # Add combined spans to locations that are adjacent to a span linked to
        # an administrative division. e.g. Seattle, WA
        span_to_geonames = defaultdict(list)
        for geoname in candidate_geonames:
            for span in geoname.spans:
                span_to_geonames[span].append(geoname)
        geoname_spans = span_to_geonames.keys()
        combined_spans = AnnoTier(geoname_spans).chains(at_least=2, at_most=4, max_dist=4)
        for combined_span in combined_spans:
            leaf_spans = combined_span.iterate_leaf_base_spans()
            first_spans = next(leaf_spans)
            # Start with every geoname of the first leaf span, then keep only
            # those for which each following leaf span offers a containing
            # geoname, accumulating the containing geonames along the way.
            potential_geonames = {geoname: set()
                                  for geoname in span_to_geonames[first_spans]}
            for leaf_span in leaf_spans:
                leaf_span_geonames = span_to_geonames[leaf_span]
                next_potential_geonames = defaultdict(set)
                for potential_geoname, prev_containing_geonames in potential_geonames.items():
                    containing_geonames = [
                        containing_geoname
                        for containing_geoname in leaf_span_geonames
                        if location_contains(containing_geoname, potential_geoname) > 0]
                    if len(containing_geonames) > 0:
                        next_potential_geonames[potential_geoname] |= prev_containing_geonames | set(containing_geonames)
                potential_geonames = next_potential_geonames
            for geoname, containing_geonames in potential_geonames.items():
                geoname.spans.add(combined_span)
                geoname.parents |= containing_geonames
        # Replace individual spans with combined spans.
        span_to_geonames = defaultdict(list)
        for geoname in candidate_geonames:
            geoname.spans = set(AnnoTier(geoname.spans).optimal_span_set().spans)
            for span in geoname.spans:
                span_to_geonames[span].append(geoname)
        # Find locations with overlapping spans
        # Note that it is possible for two valid locations to have
        # overlapping names. For example, Harare Province has
        # Harare as an alternate name, so the city Harare is very
        # likely to be an alternate location that competes with it.
        for span, geonames in span_to_geonames.items():
            geoname_set = set(geonames)
            for geoname in geonames:
                geoname.alternate_locations |= geoname_set
        for geoname in candidate_geonames:
            # A geoname is not an alternative to itself.
            geoname.alternate_locations -= set([geoname])
        logger.info('%s alternative locations found', sum(
            len(geoname.alternate_locations)
            for geoname in candidate_geonames))
        logger.info('%s candidate locations prepared',
                    len(candidate_geonames))
        return candidate_geonames

    def add_contextual_features(self, features):
        """
        Extend a list of features with values that are based on the geonames
        mentioned nearby.

        NOTE(review): as written this method only indexes the features by
        span and builds a tier from those spans; nothing is written back to
        the features and nothing is returned. The body appears truncated --
        confirm against the original implementation before relying on the
        contextual scores computed in annotate().
        """
        logger.info('adding contextual features')
        span_to_features = defaultdict(list)
        for feature in features:
            for span in feature.geoname.spans:
                span_to_features[span].append(feature)
        # Currently unused; kept as the starting point for the missing logic.
        geoname_span_tier = AnnoTier(list(span_to_features.keys()))

    def annotate(self, doc):
        """
        Resolve the geonames mentioned in doc.

        Returns a dict {'geonames': AnnoTier} whose spans are GeoSpans for
        the geonames scoring above the classifier's GEONAME_SCORE_THRESHOLD,
        with overlapping spans removed.

        When no features are extracted, an empty 'geonames' tier is stored on
        doc.tiers and doc itself is returned instead of a dict.
        """
        logger.info('geoannotator started')
        candidate_geonames = self.get_candidate_geonames(doc)
        # NOTE(review): extract_features is not defined in this class as it
        # appears here -- see the class docstring.
        features = self.extract_features(candidate_geonames, doc)
        if len(features) == 0:
            doc.tiers['geonames'] = AnnoTier([])
            return doc

        # First pass: score with the base classifier and flag the geonames
        # whose positive-class probability clears the high confidence
        # threshold.
        scores = self.geoname_classifier.predict_proba_base([
            list(f.values()) for f in features])
        for geoname, feature, score in zip(candidate_geonames, features, scores):
            geoname.high_confidence = float(
                score[1]) > self.geoname_classifier.HIGH_CONFIDENCE_THRESHOLD
            feature.set_value('high_confidence', geoname.high_confidence)
        has_high_confidence_features = any(
            geoname.high_confidence for geoname in candidate_geonames)
        # Second pass: only when something was resolved with high confidence
        # are the scores replaced by those of the contextual classifier;
        # otherwise the base scores are kept.
        if has_high_confidence_features:
            self.add_contextual_features(features)
            scores = self.geoname_classifier.predict_proba_contextual([
                list(f.values()) for f in features])
        for geoname, score in zip(candidate_geonames, scores):
            geoname.score = float(score[1])
        culled_geonames = [geoname
                           for geoname in candidate_geonames
                           if geoname.score > self.geoname_classifier.GEONAME_SCORE_THRESHOLD]
        cursor = self.connection.cursor()
        for geoname in culled_geonames:
            # Look up the country and admin1-3 names matching the geoname's
            # country/admin codes; missing codes are queried as "".
            geoname_results = list(cursor.execute('''
                SELECT
                    cc.name,
                    a1.name,
                    a2.name,
                    a3.name
                FROM adminnames a3
                JOIN adminnames a2 ON (
                    a2.country_code = a3.country_code AND
                    a2.admin1_code = a3.admin1_code AND
                    a2.admin2_code = a3.admin2_code AND
                    a2.admin3_code = "" )
                JOIN adminnames a1 ON (
                    a1.country_code = a3.country_code AND
                    a1.admin1_code = a3.admin1_code AND
                    a1.admin2_code = "" AND
                    a1.admin3_code = "" )
                JOIN adminnames cc ON (
                    cc.country_code = a3.country_code AND
                    cc.admin1_code = "00" AND
                    cc.admin2_code = "" AND
                    cc.admin3_code = "" )
                WHERE (a3.country_code = ? AND a3.admin1_code = ? AND a3.admin2_code = ? AND a3.admin3_code = ?)
                ''', (
                geoname.country_code or "",
                geoname.admin1_code or "",
                geoname.admin2_code or "",
                geoname.admin3_code or "",)))
            for result in geoname_results:
                prev_val = None
                for idx, attr in enumerate(['country_name', 'admin1_name', 'admin2_name', 'admin3_name']):
                    val = result[idx]
                    if val == prev_val:
                        # Names are repeated for admin levels beyond that of
                        # the geoname.
                        break
                    setattr(geoname, attr, val)
                    prev_val = val
        logger.info('admin names added')
        geo_spans = []
        for geoname in culled_geonames:
            for span in geoname.spans:
                geo_span = GeoSpan(
                    span.start, span.end, doc, geoname)
                geo_spans.append(geo_span)
        # Among overlapping spans, prefer the larger one and then the one
        # whose geoname scored higher.
        culled_geospans = AnnoTier(geo_spans).optimal_span_set(prefer=lambda x: (x.size(), x.geoname.score,))
        logger.info('overlapping geospans removed')
        return {'geonames': culled_geospans}

In [None]:
==abE