<a href="https://colab.research.google.com/github/KarolChlasta/address-net/blob/master/AddressNet_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [100]:
# Diagnostic only: list the package versions installed in this runtime (the captured output below is truncated)
pip list

Package                          Version
-------------------------------- ---------------------
absl-py                          1.4.0
address-net                      1.0
aiohttp                          3.8.6
aiosignal                        1.3.1
alabaster                        0.7.13
albumentations                   1.3.1
altair                           4.2.2
anyio                            3.7.1
appdirs                          1.4.4
argon2-cffi                      23.1.0
argon2-cffi-bindings             21.2.0
array-record                     0.5.0
arviz                            0.15.1
astor                            0.8.1
astropy                          5.3.4
astunparse                       1.6.3
async-timeout                    4.0.3
atpublic                         4.0
attrs                            23.1.0
audioread                        3.0.1
autograd                         1.6.2
Babel                            2.13.1
backcall                         0.2.0
beaut

In [101]:
pip install git+https://github.com/jasonrig/address-net.git

Collecting git+https://github.com/jasonrig/address-net.git
  Cloning https://github.com/jasonrig/address-net.git to /tmp/pip-req-build-8csookdd
  Running command git clone --filter=blob:none --quiet https://github.com/jasonrig/address-net.git /tmp/pip-req-build-8csookdd
  Resolved https://github.com/jasonrig/address-net.git to commit 28e7c2de030bae56f81c66d7e640dcc2d04fdfb6
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [102]:
# NOTE(review): the wheel installed below comes from the .../tensorflow/mac/cpu/... URL, i.e. a macOS CPU build,
# which may not be usable on a Linux (Colab) runtime; the captured output also shows no "Successfully installed"
# line. Confirm which TensorFlow version is actually importable (tf.__version__) before running the cells below.
#pip install tensorflow-gpu==1.12.0
#pip install -q condacolab
#condacolab search python
#!curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh

!python -m pip install --upgrade https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.12.0-py3-none-any.whl

Collecting tensorflow==1.12.0
  Using cached https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.12.0-py3-none-any.whl (62.0 MB)


In [103]:
import random
import string
from collections import OrderedDict
from typing import Callable, Iterator, List, Optional, Union

import numpy as np
import tensorflow as tf

import addressnet.lookups as lookups
from addressnet.typo import generate_typo

# Schema used to decode data from the TFRecord file. Key order is significant: dataset() passes the parsed values to
# synthesise_address in this order, and synthesise_address zips them back onto these keys positionally.
_features = OrderedDict(
    # Free-text fields, stored as byte strings
    [(name, tf.io.FixedLenFeature([], tf.string)) for name in (
        'building_name', 'lot_number_prefix', 'lot_number', 'lot_number_suffix', 'flat_number_prefix',
        'flat_number_suffix', 'level_number_prefix', 'level_number_suffix', 'number_first_prefix',
        'number_first_suffix', 'number_last_prefix', 'number_last_suffix', 'street_name', 'locality_name',
        'postcode')]
    # Numbers and categorical codes
    + [(name, tf.io.FixedLenFeature([], tf.int64)) for name in (
        'flat_number', 'level_number', 'number_first', 'number_last', 'flat_type', 'level_type',
        'street_type_code', 'street_suffix_code', 'state_abbreviation')]
    # Coordinates
    + [(name, tf.io.FixedLenFeature([], tf.float32)) for name in ('latitude', 'longitude')]
)

# Fields that characters can be labelled with. Column 0 of a label matrix is the blank (separator) category, so the
# field at list index i occupies column i + 1 -- the trailing comments give that column number.
labels_list = [
    'building_name',  # 1
    'level_number_prefix',  # 2
    'level_number',  # 3
    'level_number_suffix',  # 4
    'level_type',  # 5
    'flat_number_prefix',  # 6
    'flat_number',  # 7
    'flat_number_suffix',  # 8
    'flat_type',  # 9
    'number_first_prefix',  # 10
    'number_first',  # 11
    'number_first_suffix',  # 12
    'number_last_prefix',  # 13
    'number_last',  # 14
    'number_last_suffix',  # 15
    'street_name',  # 16
    'street_suffix_code',  # 17
    'street_type_code',  # 18
    'locality_name',  # 19
    'state_abbreviation',  # 20
    'postcode'  # 21
]
# Total number of label columns: the blank category plus one per field
n_labels = 1 + len(labels_list)

# Characters with a dedicated vocab index (lower-case letters only; vocab_lookup lower-cases its input)
vocab = [*string.digits, *string.ascii_lowercase, *string.punctuation, *string.whitespace]


def vocab_lookup(characters: str) -> (int, np.ndarray):
    """
    Converts a string into an array of vocab indices, exactly one index per input character.
    Lookup is case-insensitive. A character at position i of `vocab` encodes as i + 1; any other character
    encodes as 0.
    :param characters: the string to convert
    :return: the string length and an int64 array of vocab indices (always of that same length)
    """
    result = list()
    for c in characters:
        # Lower-case each character on its own rather than the whole string: some characters lower-case to several
        # code points (e.g. 'İ' -> 'i' + combining dot), which would make the encoded array longer than
        # len(characters) and misalign it with the per-character label rows. A multi-character result is not a
        # vocab entry, so it simply encodes as 0.
        try:
            result.append(vocab.index(c.lower()) + 1)
        except ValueError:
            result.append(0)
    return len(characters), np.array(result, dtype=np.int64)


def decode_data(record: List[Union[bytes, int, float]]) -> Iterator[Optional[Union[str, int, float]]]:
    """
    Decodes a record from the tfrecord file. Byte strings are decoded as UTF-8 text, and any numeric field with a
    value of -1 (the "missing" sentinel) becomes None. This is a generator, yielding one decoded value per item.
    :param record: the raw field values to decode
    :return: an iterator yielding the decoded fields, in the same order as `record`
    """
    for item in record:
        try:
            # String features arrive as bytes, which have .decode()
            decoded = item.decode("UTF-8")
        except AttributeError:
            # Numeric features have no .decode(); map -1 to None (see generate_tf_records.py)
            decoded = item if item != -1 else None
        # Yield outside the try so nothing raised at the suspension point can be mistaken for a numeric item
        yield decoded


def labels(text: Union[str, int], field_name: Optional[str], mutate: bool = True) -> (str, np.ndarray):
    """
    Generates a one-hot boolean matrix labelling each character of `text` with its field type. Strings have
    artificial typos introduced (via generate_typo) if mutate == True.
    :param text: the text to label; None is treated as an empty string, anything else is converted with str()
    :param field_name: the name of the field to which the text belongs (must be in labels_list), or None if the
                       label is blank
    :param mutate: introduce artificial typos
    :return: the (possibly mutated) text and a bool matrix of shape (len(text), n_labels)
    :raises ValueError: if field_name is neither None nor an entry of labels_list
    """

    # Ensure the input is a string, encoding None to an empty string
    if text is None:
        text = ''
    else:
        # Introduce artificial typos if mutate == True
        text = generate_typo(str(text)) if mutate else str(text)
    # Builtin `bool` rather than the `np.bool` alias, which NumPy deprecated in 1.20 and removed in 1.24
    labels_matrix = np.zeros((len(text), n_labels), dtype=bool)

    # If no field is supplied, then encode the label using the blank category (column 0)
    if field_name is None:
        labels_matrix[:, 0] = True
    else:
        labels_matrix[:, labels_list.index(field_name) + 1] = True
    return text, labels_matrix


def random_separator(min_length: int = 1, max_length: int = 3, possible_sep_chars: Optional[str] = r",./\  ") -> str:
    """
    Builds a separator made of spaces whose length is drawn uniformly from [min_length, max_length] (inclusive).
    If possible_sep_chars is non-empty, one randomly chosen position holds a randomly chosen character from it.
    :param min_length: minimum length of the separator
    :param max_length: maximum length of the separator
    :param possible_sep_chars: candidate separator characters; None or '' yields a spaces-only separator
    :return: the separator string
    """
    width = random.randint(min_length, max_length)
    if width <= 0 or not possible_sep_chars:
        return " " * width
    # The draw order (length, then character, then position) matters for reproducibility under a fixed seed
    marker = random.choice(possible_sep_chars)
    position = random.randrange(width)
    return " " * position + marker + " " * (width - position - 1)


def join_labels(lbls: [np.ndarray], sep: Union[str, Callable[..., str]] = " ") -> np.ndarray:
    """
    Concatenates a series of label matrices, inserting blank-category rows for a separator between runs.
    :param lbls: a list of label matrices, each with n_labels columns
    :param sep: the separator string, or a zero-argument function returning it. A function is called once for every
                matrix after the first (including zero-length ones), which keeps it in step with the separators
                join_str_and_labels pre-generates.
    :return: the concatenated labels (always a matrix: an empty input gives a (0, n_labels) matrix and a single
             matrix is returned unchanged)
    """
    # Previously fewer than two inputs returned the input *list*; always return a matrix instead
    if len(lbls) == 0:
        return np.zeros((0, n_labels), dtype=bool)
    if len(lbls) == 1:
        return lbls[0]

    # Collect pieces and concatenate once, instead of re-copying the growing matrix on every iteration
    pieces = [lbls[0]]
    joined_rows = lbls[0].shape[0]
    for l in lbls[1:]:
        # A callable separator is invoked on every iteration, even when the current matrix is skipped below
        sep_str = sep() if callable(sep) else sep

        # Skip zero-length labels
        if l.shape[0] == 0:
            continue
        # Insert separator rows only for a non-empty separator, and only after something has been joined
        if sep_str and joined_rows > 0:
            pieces.append(labels(sep_str, None, mutate=False)[1])
        pieces.append(l)
        joined_rows += l.shape[0]

    joined_labels = np.concatenate(pieces, axis=0)
    assert joined_labels.shape[1] == n_labels, "The number of labels generated was unexpected: got %i but wanted %i" % (
        joined_labels.shape[1], n_labels)

    return joined_labels


def join_str_and_labels(parts: [(str, np.ndarray)], sep: Union[str, Callable[..., str]] = " ") -> (str, np.ndarray):
    """
    Joins the strings and labels using the given separator. Parts with an empty string are dropped first.
    :param parts: a list of string/label tuples
    :param sep: a string or function that returns the string to be used as a separator
    :return: the joined string and labels (the separator characters are labelled with the blank category)
    """
    # Keep only the parts with strings of length > 0
    parts = [p for p in parts if len(p[0]) > 0]

    # No parts at all: return an empty string and an empty label matrix. The matrix is bool to match what labels()
    # produces and what dataset() declares for the labels tensor.
    if len(parts) == 0:
        return '', np.zeros((0, n_labels), dtype=bool)
    # If there's only one part, just give it back as-is
    elif len(parts) == 1:
        return parts[0]

    # Pre-generate the separators - this is important if `sep` is a function returning non-deterministic results,
    # because the very same separators must be used for the string and for the labels
    n_sep = len(parts) - 1
    if callable(sep):
        seps = [sep() for _ in range(n_sep)]
    else:
        seps = [sep] * n_sep
    seps.append('')

    # Join the strings using the list of separators (linear, unlike summing tuples)
    strings = ''.join(part[0] + part_sep for part, part_sep in zip(parts, seps))

    # Join the labels, feeding the same separators through an iterator
    sep_iter = iter(seps)
    lbls = join_labels([p[1] for p in parts], sep=lambda: next(sep_iter))

    assert len(strings) == lbls.shape[0], "string length %i (%s), label length %i using sep %s" % (
        len(strings), strings, lbls.shape[0], seps)
    return strings, lbls


def choose(option1: Callable = lambda: None, option2: Callable = lambda: None):
    """
    Flips a fair coin (one random bit) and calls one of the two functions.
    :param option1: called when the bit is 1
    :param option2: called when the bit is 0
    :return: whatever the selected function returns (None for the defaults)
    """
    picked = option1 if random.getrandbits(1) else option2
    return picked()


def synthesise_address(*record) -> (int, np.ndarray, np.ndarray):
    """
    Uses the record information to construct a formatted address with labels. The addresses generated involve
    semi-random permutations and corruptions to help avoid over-fitting.

    Every component is generated first (in a fixed order; several of the generators draw random numbers). Then
    components are randomly kept/dropped and shuffled, and finally joined with random separators. When kept,
    suburb/state/postcode always come last.
    :param record: the raw field values for one example, in the same order as the keys of `_features`
                   (dataset() passes them this way through tf.py_func)
    :return: the address string length, the vocab-encoded text (int64 array) and the per-character label matrix
    """
    # Values are matched to field names positionally, so `record` must follow `_features` key order
    fields = dict(zip(_features.keys(), decode_data(record)))

    # Generate the individual address components:
    # NOTE(review): decode_data maps -1 to None, and `None > 0` raises TypeError in Python 3; confirm that
    # level_type/flat_type are never stored as -1 in the TFRecords.
    if fields['level_type'] > 0:
        level = generate_level_number(fields['level_type'], fields['level_number_prefix'], fields['level_number'],
                                      fields['level_number_suffix'])
    else:
        # Placeholder empty component; dropped later because its string is empty
        level = ('', np.zeros((0, n_labels)))

    if fields['flat_type'] > 0:
        flat_number = generate_flat_number(
            fields['flat_type'], fields['flat_number_prefix'], fields['flat_number'], fields['flat_number_suffix'])
    else:
        flat_number = ('', np.zeros((0, n_labels)))

    street_number = generate_street_number(fields['number_first_prefix'], fields['number_first'],
                                           fields['number_first_suffix'], fields['number_last_prefix'],
                                           fields['number_last'], fields['number_last_suffix'])
    street = generate_street_name(fields['street_name'], fields['street_suffix_code'], fields['street_type_code'])
    suburb = labels(fields['locality_name'], 'locality_name')
    state = generate_state(fields['state_abbreviation'])
    postcode = labels(fields['postcode'], 'postcode')
    building_name = labels(fields['building_name'], 'building_name')

    # Begin composing the formatted address, building up the `parts` variable...
    # `parts` is a list of groups; groups are shuffled as units, and the components inside a group stay in order

    suburb_state_postcode = list()
    # Keep the suburb?
    choose(lambda: suburb_state_postcode.append(suburb))
    # Keep state?
    choose(lambda: suburb_state_postcode.append(state))
    # Keep postcode?
    choose(lambda: suburb_state_postcode.append(postcode))

    random.shuffle(suburb_state_postcode)

    parts = [[building_name], [level]]

    # Keep the street number? (If street number is dropped, the flat number is also dropped)
    def keep_street_number():
        # force flat number to be next to street number only if the flat number is only digits (i.e. does not have a
        # flat type)
        # An absent flat number ('') is not all digits, so it takes the else branch and is later dropped as empty
        if flat_number[0].isdigit():
            parts.append([flat_number, street_number, street])
        else:
            parts.append([flat_number])
            parts.append([street_number, street])

    # Either keep the street number (and flat number) or keep only the street; the street is always present
    choose(keep_street_number, lambda: parts.append([street]))

    random.shuffle(parts)

    # Suburb, state, postcode is always at the end of an address
    parts.append(suburb_state_postcode)

    # Flatten the address components into an unnested list
    parts = sum(parts, [])

    # Join each address component/label with a random separator (empty components are dropped by the join)
    address, address_lbl = join_str_and_labels(parts, sep=lambda: random_separator(1, 3))

    # Encode
    length, text_encoded = vocab_lookup(address)
    return length, text_encoded, address_lbl


def generate_state(state_abbreviation: int) -> (str, np.ndarray):
    """
    Generates the string and labels for the state, randomly written in full or abbreviated.
    :param state_abbreviation: the state code (0 means no/unknown state)
    :return: string and labels
    """
    state = lookups.lookup_state(state_abbreviation, reverse_lookup=True)
    # Code 0 reverse-looks-up to '', and expanding '' would raise KeyError, so an empty state is kept as-is.
    # choose() is still called either way, so the random stream is consumed exactly as before.
    return labels(choose(lambda: lookups.expand_state(state) if state else state, lambda: state),
                  'state_abbreviation')


def generate_level_number(level_type: int, level_number_prefix: str, level_number: int, level_number_suffix: str) -> (
        str, np.ndarray):
    """
    Generates the level number for the address.
    It randomly chooses between two renderings. The first is a "transformed" level: ordinal or cardinal form,
    possible only when there is no prefix/suffix and the level type string is non-empty. The second is the plain
    form: level type, prefix, number and suffix joined by 1-3 spaces.
    :param level_type: level type code
    :param level_number_prefix: number prefix
    :param level_number: level number
    :param level_number_suffix: level number suffix
    :return: string and labels
    """

    level_type = labels(lookups.lookup_level_type(level_type, reverse_lookup=True), 'level_type')

    # Decide whether to transform the level number
    # Returns None when the no-prefix/no-suffix/non-empty-type condition is not met, which selects the plain form.
    # NOTE(review): level_number may be None (decode_data maps -1 to None); whether lookups.num2word accepts None
    # is not visible here -- confirm.
    def do_transformation():
        if not level_number_prefix and not level_number_suffix and level_type[0]:
            # If there is no prefix/suffix, decide whether to convert to ordinal numbers (1st, 2nd, etc.)
            def use_ordinal_numbers(lvl_num, lvl_type):
                # Use ordinal words (first, second, third) or numbers (1st, 2nd, 3rd)?
                lvl_num = choose(lambda: lookups.num2word(lvl_num, output='ordinal_words'),
                                 lambda: lookups.num2word(lvl_num, output='ordinal'))
                lvl_num = labels(lvl_num, 'level_number')
                # Ordinal form puts the number before the type, e.g. "<ordinal> <type>"
                return join_str_and_labels([lvl_num, lvl_type],
                                           sep=lambda: random_separator(1, 3, possible_sep_chars=None))

            def use_cardinal_numbers(lvl_num, lvl_type):
                # Treat level 1 as GROUND?
                # (Level 1 becomes either "GROUND" or the digit 1; every other number goes through
                # num2word(..., output='cardinal') in this branch.)
                if lvl_num == 1:
                    lvl_num = choose(lambda: "GROUND", lambda: 1)
                else:
                    lvl_num = lookups.num2word(lvl_num, output='cardinal')
                lvl_num = labels(lvl_num, 'level_number')
                # Cardinal form puts the type first, e.g. "<type> <cardinal>"
                return join_str_and_labels([lvl_type, lvl_num],
                                           sep=lambda: random_separator(1, 3, possible_sep_chars=None))

            return choose(lambda: use_ordinal_numbers(level_number, level_type),
                          lambda: use_cardinal_numbers(level_number, level_type))

    # Half the time do_transformation is not even run (choose's default option returns None); a None result from
    # either path falls through to the plain rendering below
    transformed_value = choose(do_transformation)
    if transformed_value:
        return transformed_value
    else:
        level_number_prefix = labels(level_number_prefix, 'level_number_prefix')
        level_number = labels(level_number, 'level_number')
        level_number_suffix = labels(level_number_suffix, 'level_number_suffix')
        return join_str_and_labels([level_type, level_number_prefix, level_number, level_number_suffix],
                                   sep=lambda: random_separator(1, 3, possible_sep_chars=None))


def generate_flat_number(
        flat_type: int, flat_number_prefix: str, flat_number: int, flat_number_suffix: str) -> (str, np.ndarray):
    """
    Renders a flat number (prefix, number, suffix joined by 0-2 spaces), randomly preceded by its flat type.
    :param flat_type: flat type code
    :param flat_number_prefix: number prefix
    :param flat_number: number
    :param flat_number_suffix: number suffix
    :return: string and labels
    """
    # labels() may introduce random typos, so the components are labelled in exactly this order
    type_part = labels(lookups.lookup_flat_type(flat_type, reverse_lookup=True), 'flat_type')
    number_components = [
        labels(flat_number_prefix, 'flat_number_prefix'),
        labels(flat_number, 'flat_number'),
        labels(flat_number_suffix, 'flat_number_suffix'),
    ]
    number_part = join_str_and_labels(number_components,
                                      sep=lambda: random_separator(0, 2, possible_sep_chars=None))

    def with_flat_type():
        # The separator is drawn only when this option is taken
        return join_str_and_labels([type_part, number_part], sep=random_separator(0, 2, possible_sep_chars=None))

    return choose(with_flat_type, lambda: number_part)


def generate_street_number(number_first_prefix: str, number_first: int, number_first_suffix,
                           number_last_prefix, number_last, number_last_suffix) -> (str, np.ndarray):
    """
    Renders a street number, or a number range, from its first/last number components.
    Each number is joined to its prefix/suffix with 0-2 spaces. The first and last numbers are joined by a 1-3
    character separator that may contain '-', '\\' or '/'.
    :param number_first_prefix: prefix to the first street number
    :param number_first: first street number
    :param number_first_suffix: suffix to the first street number
    :param number_last_prefix: prefix to the last street number
    :param number_last: last street number
    :param number_last_suffix: suffix to the last street number
    :return: the street number string and labels
    """
    # All six components are labelled up front, in this order (labelling may introduce random typos)
    first_components = [
        labels(number_first_prefix, 'number_first_prefix'),
        labels(number_first, 'number_first'),
        labels(number_first_suffix, 'number_first_suffix'),
    ]
    last_components = [
        labels(number_last_prefix, 'number_last_prefix'),
        labels(number_last, 'number_last'),
        labels(number_last_suffix, 'number_last_suffix'),
    ]

    def tight_separator():
        return random_separator(0, 2, possible_sep_chars=None)

    first_number = join_str_and_labels(first_components, tight_separator)
    last_number = join_str_and_labels(last_components, tight_separator)

    # The range separator is drawn even when there is no last number to join
    range_separator = random_separator(1, 3, possible_sep_chars=r"----   \/")
    return join_str_and_labels([first_number, last_number], sep=range_separator)


def generate_street_name(street_name: str, street_suffix_code: str, street_type_code: str) -> (str, np.ndarray):
    """
    Renders the street name together with its type and suffix. The type is randomly abbreviated, the suffix is
    randomly expanded, and suffix/type appear after the name in a random order (space separated).
    :param street_name: the street's name
    :param street_suffix_code: the street suffix code (integer code, despite the annotation)
    :param street_type_code: the street type code (integer code, despite the annotation)
    :return: string and labels
    """
    name_part = labels(street_name, 'street_name')

    full_type = lookups.lookup_street_type(street_type_code, reverse_lookup=True)
    type_part = labels(choose(lambda: lookups.abbreviate_street_type(full_type), lambda: full_type),
                       'street_type_code')

    short_suffix = lookups.lookup_street_suffix(street_suffix_code, reverse_lookup=True)
    suffix_part = labels(choose(lambda: lookups.expand_street_type_suffix(short_suffix), lambda: short_suffix),
                         'street_suffix_code')

    ordered_parts = choose(lambda: [name_part, suffix_part, type_part],
                           lambda: [name_part, type_part, suffix_part])
    return join_str_and_labels(ordered_parts)


def dataset(filenames: [str], batch_size: int = 10, shuffle_buffer: int = 1000, prefetch_buffer_size: int = 10000,
            num_parallel_calls: int = 8) -> Callable:
    """
    Creates an input_fn for training. The pipeline reads GZIP TFRecords, shuffles, parses each record, synthesises a
    labelled address from it, pad-batches, and splits each batch into (features, labels).
    :param filenames: the tfrecord filenames (GZIP-compressed)
    :param batch_size: training batch size
    :param shuffle_buffer: shuffle buffer size
    :param prefetch_buffer_size: size of the prefetch buffer
    :param num_parallel_calls: number of parallel calls for each of the mapping functions
    :return: the input_fn
    """
    # tf.py_func is not in the TF2 top-level namespace; fall back to the v1 compat endpoint there
    py_func = getattr(tf, 'py_func', None) or tf.compat.v1.py_func

    def input_fn() -> tf.data.Dataset:
        ds = tf.data.TFRecordDataset(filenames, compression_type="GZIP")
        ds = ds.shuffle(buffer_size=shuffle_buffer)
        # tf.io.parse_single_example is the namespaced endpoint (the same namespace _features uses) and, unlike
        # tf.parse_single_example, still exists in TF2. This map previously hardcoded 8 parallel calls.
        ds = ds.map(lambda record: tf.io.parse_single_example(record, features=_features),
                    num_parallel_calls=num_parallel_calls)
        # Field values are passed positionally in _features key order, which synthesise_address relies on.
        # NOTE(review): synthesise_address is random yet marked stateful=False; inputs differ per record, but
        # confirm this is intended.
        ds = ds.map(
            lambda record: py_func(synthesise_address, [record[k] for k in _features.keys()],
                                   [tf.int64, tf.int64, tf.bool],
                                   stateful=False),
            num_parallel_calls=num_parallel_calls
        )

        # Pad the variable-length encoded text and label matrices within each batch
        ds = ds.padded_batch(batch_size, ([], [None], [None, n_labels]))

        ds = ds.map(
            lambda _lengths, _encoded_text, _labels: ({'lengths': _lengths, 'encoded_text': _encoded_text}, _labels),
            num_parallel_calls=num_parallel_calls
        )
        ds = ds.prefetch(buffer_size=prefetch_buffer_size)
        return ds

    return input_fn


def predict_input_fn(input_text: List[str]) -> Callable:
    """
    Builds an input_fn for prediction that feeds each address in `input_text` as its own batch of size 1.
    :param input_text: the address strings to encode
    :return: a zero-argument input_fn returning a tf.data.Dataset of {'lengths', 'encoded_text'} feature dicts
    """

    def encoded_addresses():
        # (length, int64 index array) per address, as produced by vocab_lookup
        for address in input_text:
            yield vocab_lookup(address)

    def to_features(lengths, encoded_text):
        return {'lengths': lengths, 'encoded_text': encoded_text}

    def input_fn() -> tf.data.Dataset:
        return tf.data.Dataset.from_generator(
            encoded_addresses,
            (tf.int64, tf.int64),
            (tf.TensorShape([]), tf.TensorShape([None]))
        ).batch(1).map(to_features)

    return input_fn

In [104]:
from collections import OrderedDict
from typing import Union

# Categorical types as per the GNAF dataset, see: https://data.gov.au/dataset/geocoded-national-address-file-g-naf
# The visible lookup helpers (e.g. lookup_state, lookup_street_type) encode a value as its position in one of these
# collections plus one (via _lookup), with 0 reserved for "not found". Reordering or inserting entries therefore
# changes the integer codes, which must stay consistent with any previously generated data.
flat_types = ('ANTENNA', 'APARTMENT', 'AUTOMATED TELLER MACHINE', 'BARBECUE', 'BLOCK', 'BOATSHED', 'BUILDING',
              'BUNGALOW', 'CAGE', 'CARPARK', 'CARSPACE', 'CLUB', 'COOLROOM', 'COTTAGE', 'DUPLEX', 'FACTORY', 'FLAT',
              'GARAGE', 'HALL', 'HOUSE', 'KIOSK', 'LEASE', 'LOBBY', 'LOFT', 'LOT', 'MAISONETTE', 'MARINE BERTH',
              'OFFICE', 'PENTHOUSE', 'REAR', 'RESERVE', 'ROOM', 'SECTION', 'SHED', 'SHOP', 'SHOWROOM', 'SIGN', 'SITE',
              'STALL', 'STORE', 'STRATA UNIT', 'STUDIO', 'SUBSTATION', 'SUITE', 'TENANCY', 'TOWER', 'TOWNHOUSE',
              'UNIT', 'VAULT', 'VILLA', 'WARD', 'WAREHOUSE', 'WORKSHOP')

level_types = ('BASEMENT', 'FLOOR', 'GROUND', 'LEVEL', 'LOBBY', 'LOWER GROUND FLOOR', 'MEZZANINE', 'OBSERVATION DECK',
               'PARKING', 'PENTHOUSE', 'PLATFORM', 'PODIUM', 'ROOFTOP', 'SUB-BASEMENT', 'UPPER GROUND FLOOR')

street_types = ('ACCESS', 'ACRE', 'AIRWALK', 'ALLEY', 'ALLEYWAY', 'AMBLE', 'APPROACH', 'ARCADE', 'ARTERIAL', 'ARTERY',
                'AVENUE', 'BANAN', 'BANK', 'BAY', 'BEACH', 'BEND', 'BOARDWALK', 'BOULEVARD', 'BOULEVARDE', 'BOWL',
                'BRACE', 'BRAE', 'BRANCH', 'BREAK', 'BRETT', 'BRIDGE', 'BROADWALK', 'BROADWAY', 'BROW', 'BULL',
                'BUSWAY', 'BYPASS', 'BYWAY', 'CAUSEWAY', 'CENTRE', 'CENTREWAY', 'CHASE', 'CIRCLE', 'CIRCLET',
                'CIRCUIT', 'CIRCUS', 'CLOSE', 'CLUSTER', 'COLONNADE', 'COMMON', 'COMMONS', 'CONCORD', 'CONCOURSE',
                'CONNECTION', 'COPSE', 'CORNER', 'CORSO', 'COURSE', 'COURT', 'COURTYARD', 'COVE', 'CRESCENT', 'CREST',
                'CRIEF', 'CROOK', 'CROSS', 'CROSSING', 'CRUISEWAY', 'CUL-DE-SAC', 'CUT', 'CUTTING', 'DALE', 'DASH',
                'DELL', 'DENE', 'DEVIATION', 'DIP', 'DISTRIBUTOR', 'DIVIDE', 'DOCK', 'DOMAIN', 'DOWN', 'DOWNS',
                'DRIVE', 'DRIVEWAY', 'EASEMENT', 'EAST', 'EDGE', 'ELBOW', 'END', 'ENTRANCE', 'ESPLANADE', 'ESTATE',
                'EXPRESSWAY', 'EXTENSION', 'FAIRWAY', 'FIREBREAK', 'FIRELINE', 'FIRETRACK', 'FIRETRAIL', 'FLAT',
                'FLATS', 'FOLLOW', 'FOOTWAY', 'FORD', 'FORESHORE', 'FORK', 'FORMATION', 'FREEWAY', 'FRONT', 'FRONTAGE',
                'GAP', 'GARDEN', 'GARDENS', 'GATE', 'GATEWAY', 'GLADE', 'GLEN', 'GRANGE', 'GREEN', 'GROVE', 'GULLY',
                'HARBOUR', 'HAVEN', 'HEATH', 'HEIGHTS', 'HIGHROAD', 'HIGHWAY', 'HIKE', 'HILL', 'HILLS', 'HOLLOW',
                'HUB', 'INLET', 'INTERCHANGE', 'ISLAND', 'JUNCTION', 'KEY', 'KEYS', 'KNOLL', 'LADDER', 'LANDING',
                'LANE', 'LANEWAY', 'LEAD', 'LEADER', 'LINE', 'LINK', 'LOOKOUT', 'LOOP', 'LYNNE', 'MALL', 'MANOR',
                'MART', 'MAZE', 'MEAD', 'MEANDER', 'MEW', 'MEWS', 'MILE', 'MOTORWAY', 'NOOK', 'NORTH', 'NULL',
                'OUTLET', 'OUTLOOK', 'OVAL', 'PALMS', 'PARADE', 'PARADISE', 'PARK', 'PARKWAY', 'PART', 'PASS',
                'PASSAGE', 'PATH', 'PATHWAY', 'PENINSULA', 'PIAZZA', 'PLACE', 'PLAZA', 'POCKET', 'POINT', 'PORT',
                'PRECINCT', 'PROMENADE', 'PURSUIT', 'QUAD', 'QUADRANT', 'QUAY', 'QUAYS', 'RAMBLE', 'RAMP', 'RANGE',
                'REACH', 'REEF', 'RESERVE', 'REST', 'RETREAT', 'RETURN', 'RIDE', 'RIDGE', 'RIGHT OF WAY', 'RING',
                'RISE', 'RISING', 'RIVER', 'ROAD', 'ROADS', 'ROADWAY', 'ROTARY', 'ROUND', 'ROUTE', 'ROW', 'ROWE',
                'RUE', 'RUN', 'SERVICEWAY', 'SHUNT', 'SKYLINE', 'SLOPE', 'SOUTH', 'SPUR', 'SQUARE', 'STEPS',
                'STRAIGHT', 'STRAIT', 'STRAND', 'STREET', 'STRIP', 'SUBWAY', 'TARN', 'TERRACE', 'THOROUGHFARE',
                'THROUGHWAY', 'TOLLWAY', 'TOP', 'TOR', 'TRACK', 'TRAIL', 'TRAMWAY', 'TRAVERSE', 'TRIANGLE', 'TRUNKWAY',
                'TUNNEL', 'TURN', 'TWIST', 'UNDERPASS', 'VALE', 'VALLEY', 'VERGE', 'VIADUCT', 'VIEW', 'VIEWS', 'VILLA',
                'VILLAGE', 'VILLAS', 'VISTA', 'VUE', 'WADE', 'WALK', 'WALKWAY', 'WATERS', 'WATERWAY', 'WAY', 'WEST',
                'WHARF', 'WOOD', 'WOODS', 'WYND', 'YARD')

# Street suffix abbreviation -> full form
street_suffix_types = OrderedDict([('CN', 'CENTRAL'), ('DE', 'DEVIATION'), ('E', 'EAST'), ('EX', 'EXTENSION'),
                                   ('IN', 'INNER'), ('LR', 'LOWER'), ('ML', 'MALL'), ('N', 'NORTH'),
                                   ('NE', 'NORTH EAST'), ('NW', 'NORTH WEST'), ('OF', 'OFF'), ('ON', 'ON'),
                                   ('OT', 'OUTER'), ('OP', 'OVERPASS'), ('S', 'SOUTH'), ('SE', 'SOUTH EAST'),
                                   ('SW', 'SOUTH WEST'), ('UP', 'UPPER'), ('W', 'WEST')])

# State abbreviation -> full name; lookup_state encodes the abbreviations (keys) by position
states = OrderedDict([('ACT', 'AUSTRALIAN CAPITAL TERRITORY'), ('NSW', 'NEW SOUTH WALES'),
                      ('NT', 'NORTHERN TERRITORY'), ('OT', 'OTHER TERRITORIES'), ('QLD', 'QUEENSLAND'),
                      ('SA', 'SOUTH AUSTRALIA'), ('TAS', 'TASMANIA'), ('VIC', 'VICTORIA'),
                      ('WA', 'WESTERN AUSTRALIA')])

# Abbreviations from METeOR identifier: 429387
# see https://meteor.aihw.gov.au/content/index.phtml/itemId/429387/pageDefinitionItemId/tag.MeteorPrinterFriendlyPage
# Full street type -> abbreviation. Not every entry of street_types has one (e.g. 'ACRE' is absent); presumably
# abbreviate_street_type handles the missing cases -- verify against its body.
street_type_abbreviation = {'ACCESS': 'ACCS', 'ALLEY': 'ALLY', 'ALLEYWAY': 'ALWY', 'AMBLE': 'AMBL', 'APPROACH': 'APP',
                            'ARCADE': 'ARC', 'ARTERIAL': 'ARTL', 'ARTERY': 'ARTY', 'AVENUE': 'AV', 'BANAN': 'BA',
                            'BEND': 'BEND', 'BOARDWALK': 'BWLK', 'BOULEVARD': 'BVD', 'BRACE': 'BR', 'BRAE': 'BRAE',
                            'BREAK': 'BRK', 'BROW': 'BROW', 'BYPASS': 'BYPA', 'BYWAY': 'BYWY', 'CAUSEWAY': 'CSWY',
                            'CENTRE': 'CTR', 'CHASE': 'CH', 'CIRCLE': 'CIR', 'CIRCUIT': 'CCT', 'CIRCUS': 'CRCS',
                            'CLOSE': 'CL', 'CONCOURSE': 'CON', 'COPSE': 'CPS', 'CORNER': 'CNR', 'COURT': 'CT',
                            'COURTYARD': 'CTYD', 'COVE': 'COVE', 'CRESCENT': 'CR', 'CREST': 'CRST', 'CROSS': 'CRSS',
                            'CUL-DE-SAC': 'CSAC', 'CUTTING': 'CUTT', 'DALE': 'DALE', 'DIP': 'DIP', 'DRIVE': 'DR',
                            'DRIVEWAY': 'DVWY', 'EDGE': 'EDGE', 'ELBOW': 'ELB', 'END': 'END', 'ENTRANCE': 'ENT',
                            'ESPLANADE': 'ESP', 'EXPRESSWAY': 'EXP', 'FAIRWAY': 'FAWY', 'FOLLOW': 'FOLW',
                            'FOOTWAY': 'FTWY', 'FORMATION': 'FORM', 'FREEWAY': 'FWY', 'FRONTAGE': 'FRTG',
                            'GAP': 'GAP', 'GARDENS': 'GDNS', 'GATE': 'GTE', 'GLADE': 'GLDE', 'GLEN': 'GLEN',
                            'GRANGE': 'GRA', 'GREEN': 'GRN', 'GROVE': 'GR', 'HEIGHTS': 'HTS', 'HIGHROAD': 'HIRD',
                            'HIGHWAY': 'HWY', 'HILL': 'HILL', 'INTERCHANGE': 'INTG', 'JUNCTION': 'JNC', 'KEY': 'KEY',
                            'LANE': 'LANE', 'LANEWAY': 'LNWY', 'LINE': 'LINE', 'LINK': 'LINK', 'LOOKOUT': 'LKT',
                            'LOOP': 'LOOP', 'MALL': 'MALL', 'MEANDER': 'MNDR', 'MEWS': 'MEWS', 'MOTORWAY': 'MTWY',
                            'NOOK': 'NOOK', 'OUTLOOK': 'OTLK', 'PARADE': 'PDE', 'PARKWAY': 'PWY', 'PASS': 'PASS',
                            'PASSAGE': 'PSGE', 'PATH': 'PATH', 'PATHWAY': 'PWAY', 'PIAZZA': 'PIAZ', 'PLAZA': 'PLZA',
                            'POCKET': 'PKT', 'POINT': 'PNT', 'PORT': 'PORT', 'PROMENADE': 'PROM', 'QUADRANT': 'QDRT',
                            'QUAYS': 'QYS', 'RAMBLE': 'RMBL', 'REST': 'REST', 'RETREAT': 'RTT', 'RIDGE': 'RDGE',
                            'RISE': 'RISE', 'ROAD': 'RD', 'ROTARY': 'RTY', 'ROUTE': 'RTE', 'ROW': 'ROW', 'RUE': 'RUE',
                            'SERVICEWAY': 'SVWY', 'SHUNT': 'SHUN', 'SPUR': 'SPUR', 'SQUARE': 'SQ', 'STREET': 'ST',
                            'SUBWAY': 'SBWY', 'TARN': 'TARN', 'TERRACE': 'TCE', 'THOROUGHFARE': 'THFR',
                            'TOLLWAY': 'TLWY', 'TOP': 'TOP', 'TOR': 'TOR', 'TRACK': 'TRK', 'TRAIL': 'TRL',
                            'TURN': 'TURN', 'UNDERPASS': 'UPAS', 'VALE': 'VALE', 'VIADUCT': 'VIAD', 'VIEW': 'VIEW',
                            'VISTA': 'VSTA', 'WALK': 'WALK', 'WALKWAY': 'WKWY', 'WHARF': 'WHRF', 'WYND': 'WYND'}

# Number words for 1..100: index i holds the word for the number i + 1 ('first' / 'one' at index 0).
# Presumably consumed by num2word (not shown in this section) -- verify how it indexes these lists.
ordinal_words = [
    'first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth', 'ninth', 'tenth', 'eleventh',
    'twelfth', 'thirteenth', 'fourteenth', 'fifteenth', 'sixteenth', 'seventeenth', 'eighteenth', 'nineteenth',
    'twentieth', 'twenty-first', 'twenty-second', 'twenty-third', 'twenty-fourth', 'twenty-fifth', 'twenty-sixth',
    'twenty-seventh', 'twenty-eighth', 'twenty-ninth', 'thirtieth', 'thirty-first', 'thirty-second', 'thirty-third',
    'thirty-fourth', 'thirty-fifth', 'thirty-sixth', 'thirty-seventh', 'thirty-eighth', 'thirty-ninth', 'fortieth',
    'forty-first', 'forty-second', 'forty-third', 'forty-fourth', 'forty-fifth', 'forty-sixth', 'forty-seventh',
    'forty-eighth', 'forty-ninth', 'fiftieth', 'fifty-first', 'fifty-second', 'fifty-third', 'fifty-fourth',
    'fifty-fifth', 'fifty-sixth', 'fifty-seventh', 'fifty-eighth', 'fifty-ninth', 'sixtieth', 'sixty-first',
    'sixty-second', 'sixty-third', 'sixty-fourth', 'sixty-fifth', 'sixty-sixth', 'sixty-seventh', 'sixty-eighth',
    'sixty-ninth', 'seventieth', 'seventy-first', 'seventy-second', 'seventy-third', 'seventy-fourth', 'seventy-fifth',
    'seventy-sixth', 'seventy-seventh', 'seventy-eighth', 'seventy-ninth', 'eightieth', 'eighty-first', 'eighty-second',
    'eighty-third', 'eighty-fourth', 'eighty-fifth', 'eighty-sixth', 'eighty-seventh', 'eighty-eighth', 'eighty-ninth',
    'ninetieth', 'ninety-first', 'ninety-second', 'ninety-third', 'ninety-fourth', 'ninety-fifth', 'ninety-sixth',
    'ninety-seventh', 'ninety-eighth', 'ninety-ninth', 'one hundredth'
]

cardinal_words = [
    'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen',
    'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen', 'twenty', 'twenty-one', 'twenty-two',
    'twenty-three', 'twenty-four', 'twenty-five', 'twenty-six', 'twenty-seven', 'twenty-eight', 'twenty-nine', 'thirty',
    'thirty-one', 'thirty-two', 'thirty-three', 'thirty-four', 'thirty-five', 'thirty-six', 'thirty-seven',
    'thirty-eight', 'thirty-nine', 'forty', 'forty-one', 'forty-two', 'forty-three', 'forty-four', 'forty-five',
    'forty-six', 'forty-seven', 'forty-eight', 'forty-nine', 'fifty', 'fifty-one', 'fifty-two', 'fifty-three',
    'fifty-four', 'fifty-five', 'fifty-six', 'fifty-seven', 'fifty-eight', 'fifty-nine', 'sixty', 'sixty-one',
    'sixty-two', 'sixty-three', 'sixty-four', 'sixty-five', 'sixty-six', 'sixty-seven', 'sixty-eight', 'sixty-nine',
    'seventy', 'seventy-one', 'seventy-two', 'seventy-three', 'seventy-four', 'seventy-five', 'seventy-six',
    'seventy-seven', 'seventy-eight', 'seventy-nine', 'eighty', 'eighty-one', 'eighty-two', 'eighty-three',
    'eighty-four', 'eighty-five', 'eighty-six', 'eighty-seven', 'eighty-eight', 'eighty-nine', 'ninety', 'ninety-one',
    'ninety-two', 'ninety-three', 'ninety-four', 'ninety-five', 'ninety-six', 'ninety-seven', 'ninety-eight',
    'ninety-nine', 'one hundred'
]


def _lookup(t: str, types: [str]) -> int:
    """
    Looks up the value, t, from the array of types
    :param t: value to lookup
    :param types: list of types from which to lookup
    :return: an integer value > 0 if found, or 0 if not found
    """
    try:
        return types.index(t.strip().upper()) + 1
    except ValueError:
        return 0


def _reverse_lookup(idx: int, types: [str]) -> str:
    """
    Converts an integer value back to the string representation
    :param idx: integer value
    :param types: list of types
    :return: the string value or None if not found (idx == 0)
    """
    if idx == 0:
        return ''
    else:
        return types[idx - 1]


def lookup_state(state: Union[str, int], reverse_lookup=False) -> Union[str, int]:
    """
    Encodes a state as an integer, or decodes an integer back to the state key
    :param state: string or int to lookup
    :param reverse_lookup: True if converting int to string, or False if string to int
    :return: the encoded (or decoded) value
    """
    state_keys = list(states.keys())
    converter = _reverse_lookup if reverse_lookup else _lookup
    return converter(state, state_keys)


def expand_state(state: str) -> str:
    """
    Maps an abbreviated state name to the full name, e.g. "VIC" -> "VICTORIA"
    Surrounding whitespace is ignored and the abbreviation is upper-cased first.
    :param state: abbreviated state
    :return: full state
    :raises KeyError: if the abbreviation is not a known state
    """
    abbreviation = state.strip().upper()
    return states[abbreviation]


def lookup_street_type(street_type: Union[str, int], reverse_lookup=False) -> Union[str, int]:
    """
    Encodes a street type as an integer, or decodes an integer back to the street type
    :param street_type: string or int to lookup
    :param reverse_lookup: True if converting int to string, or False if string to int
    :return: the encoded (or decoded) value
    """
    if not reverse_lookup:
        return _lookup(street_type, street_types)
    return _reverse_lookup(street_type, street_types)


def abbreviate_street_type(street_type: str) -> str:
    """
    Shortens a full street type to its abbreviation, e.g. "STREET" -> "ST"
    Unknown street types are returned exactly as given (not stripped or upper-cased).
    :param street_type: full street type
    :return: abbreviated street type, or the input itself when no abbreviation is known
    """
    normalised = street_type.strip().upper()
    return street_type_abbreviation.get(normalised, street_type)


def lookup_street_suffix(street_suffix: Union[str, int], reverse_lookup=False) -> Union[str, int]:
    """
    Encodes a street suffix (keyed by its abbreviation) as an integer, or decodes it back
    :param street_suffix: string or int to lookup
    :param reverse_lookup: True if converting int to string, or False if string to int
    :return: the encoded (or decoded) value
    """
    suffix_keys = list(street_suffix_types)
    if reverse_lookup:
        return _reverse_lookup(street_suffix, suffix_keys)
    return _lookup(street_suffix, suffix_keys)


def expand_street_type_suffix(street_suffix: str) -> str:
    """
    Expands an abbreviated street suffix to its full name, e.g. "N" -> "NORTH"
    Unknown suffixes are returned exactly as given (not stripped or upper-cased).
    :param street_suffix: abbreviated street suffix
    :return: full street suffix, or the input itself when it is not a known abbreviation
    """
    key = street_suffix.strip().upper()
    return street_suffix_types.get(key, street_suffix)


def lookup_level_type(level_type: Union[str, int], reverse_lookup=False) -> Union[str, int]:
    """
    Encodes a level type as an integer, or decodes an integer back to the level type
    :param level_type: string or int to lookup
    :param reverse_lookup: True if converting int to string, or False if string to int
    :return: the encoded (or decoded) value
    """
    converter = _reverse_lookup if reverse_lookup else _lookup
    return converter(level_type, level_types)


def lookup_flat_type(flat_type: Union[str, int], reverse_lookup=False) -> Union[str, int]:
    """
    Encodes a flat type as an integer, or decodes an integer back to the flat type
    :param flat_type: string or int to lookup
    :param reverse_lookup: True if converting int to string, or False if string to int
    :return: the encoded (or decoded) value
    """
    if not reverse_lookup:
        return _lookup(flat_type, flat_types)
    return _reverse_lookup(flat_type, flat_types)


# Adapted from http://code.activestate.com/recipes/576888-format-a-number-as-an-ordinal/
def num2word(value, output='ordinal_words'):
    """
    Converts zero or a *positive* integer (or their string
    representations) to an upper-cased ordinal/cardinal value.
    Values that int() cannot parse are returned unchanged.
    :param value: the number to convert
    :param output: one of 'ordinal_words' (e.g. 'FIRST'), 'ordinal' (e.g. '1ST') or 'cardinal' (e.g. 'ONE')
    :return: the converted, upper-cased string
    :raises ValueError: if 'ordinal_words' or 'cardinal' is requested for a value outside 1..100
    """
    try:
        value = int(value)
    except ValueError:
        return value

    assert output in (
        'ordinal_words', 'ordinal', 'cardinal'), "`output` must be one of 'ordinal_words', 'ordinal' or 'cardinal'"

    if output == 'ordinal':
        # The suffix depends only on the magnitude; abs() keeps negatives correct (-1 -> '-1ST').
        magnitude = abs(value)
        if magnitude % 100 // 10 == 1:
            # 11th, 12th, 13th (and 111th, ...) are exceptions to the last-digit rule
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(magnitude % 10, "th")
        return (u"%d%s" % (value, suffix)).upper()

    # The word lists hold the words for 1..100 inclusive (they end with 'one hundredth' / 'one hundred'),
    # indexed by value - 1. Check the range explicitly: previously 100 was wrongly rejected for
    # 'ordinal_words', and 0 or negatives silently wrapped to the end of the cardinal list.
    if not 1 <= value <= 100:
        raise ValueError("'%s' only supported between 1 and 100" % output)

    words = ordinal_words if output == 'ordinal_words' else cardinal_words
    return words[value - 1].upper()

In [105]:
from typing import Dict, Optional

import tensorflow as tf

from addressnet.dataset import vocab, n_labels


def model_fn(features: Dict[str, tf.Tensor], labels: tf.Tensor, mode: str, params) -> tf.estimator.EstimatorSpec:
    """
    The AddressNet model function suitable for tf.estimator.Estimator
    :param features: a dictionary containing tensors for the encoded_text and lengths
    :param labels: a label for each character designating its position in the address; may be None
                   (e.g. when predicting), in which case no loss is computed
    :param mode: indicates whether the model is being trained, evaluated or used in prediction mode
    :param params: model hyperparameters, including rnn_size (default 128) and rnn_layers (default 3)
    :return: the appropriate tf.estimator.EstimatorSpec for the model mode (None for an unrecognised mode)
    """
    encoded_text, lengths = features['encoded_text'], features['lengths']
    rnn_size = params.get("rnn_size", 128)
    rnn_layers = params.get("rnn_layers", 3)

    # tf.Variable accepts neither a name as its first positional argument (that slot is the initial
    # value) nor an `initializer` keyword, so the previous call could not build the graph.
    # tf.get_variable creates the variable named "embeddings" (presumably the name the pretrained
    # checkpoint is keyed on — confirm if the checkpoint changes).
    embeddings = tf.get_variable("embeddings", dtype=tf.float32, initializer=tf.random_normal(shape=(len(vocab), 8)))
    encoded_strings = tf.nn.embedding_lookup(embeddings, encoded_text)

    # Dropout inside the RNN is only enabled in TRAIN mode.
    logits, loss = nnet(encoded_strings, lengths, rnn_layers, rnn_size, labels, mode == tf.estimator.ModeKeys.TRAIN)

    predicted_classes = tf.argmax(logits, axis=2)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probabilities': tf.nn.softmax(logits)
        }
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    if mode == tf.estimator.ModeKeys.EVAL:
        metrics = {}
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


def nnet(encoded_strings: tf.Tensor, lengths: tf.Tensor, rnn_layers: int, rnn_size: int, labels: tf.Tensor = None,
         training: bool = True) -> (tf.Tensor, Optional[tf.Tensor]):
    """
    Generates the RNN component of the model: a stacked bidirectional GRU (cuDNN-compatible cells with
    dropout) followed by a dense layer producing one score per label class for every time step
    :param encoded_strings: a tensor containing the encoded strings (embedding vectors)
    :param lengths: a tensor of string lengths
    :param rnn_layers: number of layers to use in the RNN
    :param rnn_size: number of units in each layer
    :param labels: labels for each character in the string (optional); passed to softmax_cross_entropy,
                   so presumably one-hot over n_labels — TODO confirm
    :param training: if True, dropout will be enabled on the RNN
    :return: logits and loss (loss will be None if labels is not provided)
    """

    def rnn_cell():
        # Keep probability 0.8 on both the recurrent state and the output while training;
        # 1.0 (no dropout) otherwise.
        probs = 0.8 if training else 1.0
        return tf.contrib.rnn.DropoutWrapper(tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell(rnn_size),
                                             state_keep_prob=probs, output_keep_prob=probs)

    # Separate, identically configured stacks for the forward and backward directions.
    # NOTE(review): variable names created here must match the checkpoint; do not reorder casually.
    rnn_cell_fw = tf.nn.rnn_cell.MultiRNNCell([rnn_cell() for _ in range(rnn_layers)])
    rnn_cell_bw = tf.nn.rnn_cell.MultiRNNCell([rnn_cell() for _ in range(rnn_layers)])

    # `lengths` bounds the recurrence for each string; the final `states` are not used.
    (rnn_output_fw, rnn_output_bw), states = tf.nn.bidirectional_dynamic_rnn(rnn_cell_fw, rnn_cell_bw, encoded_strings,
                                                                             lengths, dtype=tf.float32)
    # Join both directions along axis 2 (assumes batch-major (batch, time, features) outputs,
    # the bidirectional_dynamic_rnn default — TODO confirm).
    rnn_output = tf.concat([rnn_output_fw, rnn_output_bw], axis=2)
    # NOTE(review): ELU on the logits bounds them below at -1; kept as-is because the pretrained
    # weights were presumably trained with it.
    logits = tf.layers.dense(rnn_output, n_labels, activation=tf.nn.elu)

    loss = None
    if labels is not None:
        # Zero weight for positions at or beyond each string's length so padding does not contribute
        # to the loss. sequence_mask without maxlen pads to max(lengths); assumes labels/logits have
        # that same time dimension — TODO confirm.
        mask = tf.sequence_mask(lengths, dtype=tf.float32)
        loss = tf.losses.softmax_cross_entropy(labels, logits, weights=mask)
    return logits, loss

In [106]:
import os
from functools import lru_cache
from typing import Dict, Iterator, List, Union

import tensorflow as tf
import textdistance

from addressnet.dataset import predict_input_fn, labels_list
from addressnet.lookups import street_types, street_type_abbreviation, states, street_suffix_types, flat_types, \
    level_types
from addressnet.model import model_fn


def _get_best_match(target: str, candidates: Union[List[str], Dict[str, str]], keep_idx: int = 0) -> str:
    """
    Returns the most similar string to the target given a dictionary or list of candidates. If a dictionary is provided,
    the keys and values are compared to the target, but only the requested component of the matched tuple is returned.
    :param target: the target string to be matched
    :param candidates: a key-value dictionary or list of strings
    :param keep_idx: 0 to return the key, 1 to return the value of the best match (no effect if list is supplied)
    :return: the matched string
    """
    max_sim = None
    best = None

    try:
        candidates_list = candidates.items()
    except AttributeError:
        candidates_list = [(i,) for i in candidates]
        keep_idx = 0

    for kv in candidates_list:
        if target in kv:
            return kv[keep_idx]

        for i in kv:
            similarity = _str_sim(i, target)
            if max_sim is None or similarity > max_sim:
                best = kv[keep_idx]
                max_sim = similarity
    return best


def _str_sim(a, b, fn=textdistance.jaro_winkler):
    """
    Case-insensitive string similarity, delegated to a textdistance algorithm
    :param a: a string to compare
    :param b: another string to compare
    :param fn: the textdistance algorithm to use (Jaro-Winkler by default)
    :return: the normalised similarity reported by fn for the lower-cased strings
    """
    left, right = a.lower(), b.lower()
    return fn.normalized_similarity(left, right)


def normalise_state(s: str) -> str:
    """
    Maps a (possibly misspelt or abbreviated) state to its full name.
    A known abbreviation is expanded directly; anything else is fuzzy-matched against both the
    abbreviations and full names, and the full name of the closest entry is returned.
    :param s: state string
    :return: state name in full
    """
    return states[s] if s in states else _get_best_match(s, states, keep_idx=1)


def normalise_street_type(s: str) -> str:
    """
    Maps a (possibly misspelt or abbreviated) street type to its full, standard form.
    Known street types pass through; otherwise the closest full-name/abbreviation pair is found
    and its full name (the dictionary key) is returned.
    :param s: street type string
    :return: street type in full
    """
    is_known = s in street_types
    return s if is_known else _get_best_match(s, street_type_abbreviation, keep_idx=0)


def normalise_street_suffix(s: str) -> str:
    """
    Maps a (possibly misspelt or abbreviated) street suffix to its full name.
    A known abbreviation is expanded directly; otherwise the closest abbreviation/full-name pair
    is found and its full name is returned.
    :param s: street suffix string
    :return: street suffix in full
    """
    if s not in street_suffix_types:
        return _get_best_match(s, street_suffix_types, keep_idx=1)
    return street_suffix_types[s]


def normalise_flat_type(s: str) -> str:
    """
    Snaps a flat type onto the closest entry of the known flat types.
    Exact matches are returned unchanged; anything else is fuzzy-matched.
    :param s: flat type string
    :return: flat type in full
    """
    if s not in flat_types:
        return _get_best_match(s, flat_types)
    return s


def normalise_level_type(s: str) -> str:
    """
    Snaps a level type onto the closest entry of the known level types.
    Exact matches are returned unchanged; anything else is fuzzy-matched.
    :param s: level type string
    :return: level type in full
    """
    return s if s in level_types else _get_best_match(s, level_types)


@lru_cache(maxsize=2)
def _get_estimator(model_fn, model_dir):
    """
    Builds a tf.estimator.Estimator for the given model function and model directory.
    Memoised on (model_fn, model_dir) for up to two combinations, so repeated predictions reuse
    the same Estimator object.
    """
    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)
    return estimator


def predict_one(address: str, model_dir: str = None) -> Dict[str, str]:
    """
    Convenience wrapper around predict() for a single address: segments it into components and
    normalises the categorical ones (e.g. state, street type)
    :param address: the input address string
    :param model_dir: path to trained model (None selects predict()'s default)
    :return: a dictionary with the address components separated
    """
    segmented = predict([address], model_dir)
    return next(segmented)


def predict(address: List[str], model_dir: str = None) -> Iterator[Dict[str, str]]:
    """
    Segments a set of addresses into their components and attempts to normalise categorical components,
    e.g. state, street type

    This is a generator: the model directory check and inference only run once iteration starts,
    and one dictionary is yielded per input address, in input order.

    :param address: the input list of address strings
    :param model_dir: path to trained model; defaults to the 'pretrained' directory next to this file
    :return: an iterator of dictionaries mapping component name (e.g. 'state') to its upper-cased text
    :raises AssertionError: if model_dir is not an existing directory
    """
    if model_dir is None:
        model_dir = os.path.join(os.path.dirname(__file__), 'pretrained')
    assert os.path.isdir(model_dir), "invalid model_dir provided: %s" % model_dir
    address_net_estimator = _get_estimator(model_fn, model_dir)
    result = address_net_estimator.predict(predict_input_fn(address))

    # Label names double as output keys, e.g. "..._code" / "..._abbreviation" suffixes are dropped.
    class_names = [label.replace("_code", "").replace("_abbreviation", "") for label in labels_list]

    # Categorical components whose raw predicted text is snapped onto a canonical value.
    normalisers = {
        'state': normalise_state,
        'street_type': normalise_street_type,
        'street_suffix': normalise_street_suffix,
        'flat_type': normalise_flat_type,
        'level_type': normalise_level_type,
    }

    for addr, res in zip(address, result):
        mappings = dict()
        # Upper-case character by character instead of the whole string: str.upper() can change the
        # length of a string (e.g. 'ß' -> 'SS'), which would shift every later character against its
        # class id. NOTE(review): assumes class_ids has one entry per character of `addr`.
        for char, class_id in zip(addr, res['class_ids']):
            if class_id == 0:
                # class id 0 has no entry in class_names; such characters are dropped
                continue
            cls = class_names[class_id - 1]
            mappings[cls] = mappings.get(cls, "") + char.upper()

        for component, normalise in normalisers.items():
            if component in mappings:
                mappings[component] = normalise(mappings[component])

        yield mappings

In [107]:
import random
import numpy as np

# Keyboard-neighbour table used by generate_typo: each lower-case letter or digit maps to a string of
# characters that sit near it on a keyboard (some include a space, and repeated characters are more
# likely to be drawn by random.choice).
character_replacement = {
    'a': 'qwsz',
    'b': 'nhgv ',
    'c': 'vfdx ',
    'd': 'fresxc',
    'e': 'sdfr43ws',
    'f': 'gtrdcv',
    'g': 'hytfvb',
    'h': 'juytgbn',
    'i': 'ujklo98',
    'j': 'mkiuyhn',
    'k': 'jm,loij',
    'l': 'k,.;pok',
    'm': 'njk, ',
    'n': 'bhjm ',
    'o': 'plki90p',
    'p': 'ol;[-0o',
    'q': 'asw21',
    'r': 'tfde45',
    's': 'dxzawe',
    't': 'ygfr56',
    'u': 'ijhy78',
    'v': 'cfgb ',
    'w': 'saq23e',
    'x': 'zsdc',
    'y': 'uhgt67',
    'z': 'xsa',
    '1': '2q',
    '2': '3wq1',
    '3': '4ew2',
    '4': '5re3',
    '5': '6tr4',
    '6': '7yt5',
    '7': '8uy6',
    '8': '9iu7',
    '9': '0oi8',
    '0': '-po9',
}


def generate_typo(s: str, sub_rate: float = 0.01, del_rate: float = 0.005, dupe_rate: float = 0.005,
                  transpose_rate: float = 0.01) -> str:
    """
    Produces a lower-cased copy of a string with randomly injected, plausible typing errors.
    For every character four uniform draws decide, in priority order, whether to substitute a nearby
    key (only for characters in character_replacement), delete it, duplicate it, or swap it with
    the previous output character. If every character is deleted, the whole process is retried.
    Randomness comes from both numpy's global RNG and the `random` module; seed both for reproducibility.
    :param s: the input string
    :param sub_rate: character substitution rate (0 < x < 1)
    :param del_rate: character deletion rate (0 < x < 1)
    :param dupe_rate: character duplication rate (0 < x < 1)
    :param transpose_rate: character transposition rate (0 < x < 1)
    :return: the string with typos
    """
    if len(s) == 0:
        return s

    out_chars = []
    for ch in s.lower():
        roll_sub, roll_del, roll_dup, roll_swap = np.random.uniform(size=(4,))

        if roll_sub < sub_rate and ch in character_replacement:
            # substitute with a neighbouring key
            out_chars.append(random.choice(character_replacement[ch]))
            continue
        if roll_del < del_rate:
            # drop the character entirely
            continue

        if roll_dup < dupe_rate:
            out_chars += [ch, ch]
        elif roll_swap < transpose_rate and out_chars:
            # emit this character before the previously emitted one
            previous = out_chars.pop()
            out_chars += [ch, previous]
        else:
            out_chars.append(ch)

    # Everything was deleted: try again from scratch with the same rates.
    if not out_chars:
        return generate_typo(s, sub_rate, del_rate, dupe_rate, transpose_rate)

    return ''.join(out_chars)

In [108]:
# Segment a sample address with the installed addressnet package's predict_one
from addressnet.predict import predict_one

if __name__ == "__main__":
    sample_address = "casa del gelato, 10A 24-26 high street road mount waverley vic 3183"
    print(predict_one(sample_address))

AttributeError: ignored

In [None]:
from addressnet.predict import predict_one

if __name__ == "__main__":
    components = predict_one("casa del gelato, 10A 24-26 high street road mount waverley vic 3183")
    print(components)