<a href="https://colab.research.google.com/github/BBVA/mercury-settrie/blob/master/notebooks/settrie_benchmark_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A Benchmark comparing mercury.dynamics.settrie and the python implementation

We import some python utils for timing ...

In [1]:
import os, time, psutil

We import the class `SetTrieMap()` the python implementation of settrie that can be found at https://github.com/mmihaltz/pysettrie

In [2]:
#@title
"""Set-trie-based associative array. TODO: Redo doc and ask Oscar source code origin"""
import sys
import sortedcontainers


class SetTrieMap:
    """Associative array where keys are sets.

    Mapping container for efficient storage of key-value pairs where
    the keys are sets. Uses an efficient trie implementation. Supports querying
    for values associated to subsets or supersets of stored key sets.

    Examples:
    >>> from settrie import SetTrieMap
    ... m = SetTrieMap()
    ... m.assign({1,2}, 'A')
    ... m.assign({1,2,3}, 'B')
    ... m.assign({2,3,5}, 'C')
    ... m
    [({1, 2}, 'A'), ({1, 2, 3}, 'B'), ({2, 3, 5}, 'C')]
    >>> m.get({1,2,3})
    'B'
    >>> m.get({1, 2, 3, 4}, 'Nope!')
    'Nope!'
    >>> list(m.keys())
    [{1, 2}, {1, 2, 3}, {2, 3, 5}]
    >>> m.supersets( {1,2} )
    [({1, 2}, 'A'), ({1, 2, 3}, 'B')]
    >>> m.supersets({1, 2}, mode='keys')
    [{1, 2}, {1, 2, 3}]
    >>> m.supersets({1, 2}, mode='values')
    ['A', 'B']
    """

    class Node:
        """Node object used by SetTrieMap.
        """

        def __init__(self, data=None, trie=None):
            # child nodes a.k.a. children
            self.children = sortedcontainers.SortedList()
            # flag_last: if True, this is the last element of a key set store a
            # member element of the key set. Must be a hashable
            # (i.e. hash(data) should work), that makes it comparable/orderable
            # (i.e. data1 < data2 should work; see
            # https://wiki.python.org/moin/HowTo/Sorting/) type.
            self.flag_last = False
            self.data = data
            self.data_hash = trie.idx_func(data)
            # the value associated to the key set if flag_last == True,
            # otherwise None
            self.value = None

        # comparison operators to support rich comparisons, sorting
        # etc. using self.data as key
        def __eq__(self, other):
            return self.data_hash == other.data_hash

        def __ne__(self, other):
            return self.data_hash != other.data_hash

        def __lt__(self, other):
            return self.data_hash < other.data_hash

        def __le__(self, other):
            return self.data_hash <= other.data_hash

        def __gt__(self, other):
            return self.data_hash > other.data_hash

        def __ge__(self, other):
            return self.data_hash >= other.data_hash

    def __init__(self, iterable=None, idx_func=None):
        """Set up this SetTrieMap object.  If iterable is specified, it must
           be an iterable of (keyset, value) pairs from which set-trie
           is populated.
        """
        self.idx_func = idx_func if idx_func else hash
        self.root = SetTrieMap.Node(trie=self)

        if iterable is not None:
            for key, value in iterable:
                self.assign(key, value)

    def assign(self, akey, avalue):
        """Add key akey with associated value avalue to the container.
           akey must be a sortable and iterable container type."""
        self._assign(self.root, iter(sorted(akey, key=self.idx_func)), avalue)

    def _assign(self, node, it, val):
        """Recursive function used by self.assign()."""
        try:
            data = next(it)
            nextnode = None
            try:
                # find first child with this data
                nextnode = node.children[node.children.index(
                    SetTrieMap.Node(data, trie=self))]
            except ValueError:  # not found
                nextnode = SetTrieMap.Node(data, trie=self)  # create new node
                node.children.add(nextnode)  # add to children & sort
            self._assign(nextnode, it, val)  # recurse
        except StopIteration:  # end of set to add
            node.flag_last = True
            node.value = val

    def contains(self, keyset):
        """Returns True iff this set-trie contains set keyset as a key."""
        return self._contains(self.root, iter(sorted(keyset, key=self.idx_func)))

    def __contains__(self, keyset):
        """Returns True iff this set-trie contains set keyset as a key.

        Examples:
        This method definition allows the use of the 'in' operator.
        example:

        >>> t = SetTrieMap()
        ... t.assign({1, 3}, 'M' )
        ... {1, 3} in t
        True
        """
        return self.contains(keyset)

    def _contains(self, node, it):
        """Recursive function used by self.contains()."""
        try:
            data = next(it)
            try:
                # find first child with this data
                matchnode = node.children[node.children.index(
                    SetTrieMap.Node(data, trie=self))]
                return self._contains(matchnode, it)  # recurse
            except ValueError:  # not found
                return False
        except StopIteration:
            return node.flag_last

    def get(self, keyset, default=None):
        """Return the value associated to keyset if keyset is in this
        SetTrieMap, else default.
        """
        return self._get(self.root, iter(sorted(keyset, key=self.idx_func)), default)

    def _get(self, node, it, default):
        """Recursive function used by self.get()."""
        try:
            data = next(it)
            try:
                # find first child with this data
                matchnode = node.children[node.children.index(
                    SetTrieMap.Node(data, trie=self))]
                return self._get(matchnode, it, default)  # recurse
            except ValueError:  # not found
                return default
        except StopIteration:
            return (node.value if node.flag_last else default)

    def hassuperset(self, aset):
        """Returns True iff there is at least one key set in this SetTrieMap
           that is the superset of set aset.
        """
        return self._hassuperset(self.root, list(sorted(aset, key=self.idx_func)), 0)

    def _hassuperset(self, node, setarr, idx):
        """Used by hassuperset()."""
        if idx > len(setarr) - 1:
            return True
        found = False
        for child in node.children:
            # don't go to subtrees where current element cannot be
            if child.data_hash > self.idx_func(setarr[idx]):
                break
            if child.data_hash == self.idx_func(setarr[idx]):
                found = self._hassuperset(child, setarr, idx + 1)
            else:
                found = self._hassuperset(child, setarr, idx)
            if found:
                break
        return found

    def itersupersets(self, aset, mode=None):
        """Return an iterator over all (keyset, value) pairs from this
           SetTrieMap for which set keyset is a superset (proper or
           not proper) of set aset.  If mode is not None, the
           following values are allowed:

           mode='keys': return an iterator over only the keysets that
                        are supersets of aset is returned
           mode='values': return an iterator over only the values that
                          are associated to keysets that are supersets
                          of aset

           If mode is neither of 'keys', 'values' or None, behavior is
           equivalent to mode=None.

        """
        path = []
        return self._itersupersets(self.root, list(sorted(aset, key=self.idx_func)), 0,
                                   path, mode)

    def _itersupersets(self, node, setarr, idx, path, mode):
        """Used by itersupersets()."""
        if node.data is not None:
            path.append(node.data)
        if node.flag_last and idx > len(setarr) - 1:
            if mode == 'keys':
                yield set(path)
            elif mode == 'values':
                yield node.value
            else:
                yield (set(path), node.value)
        # we still have elements of aset to find
        if idx <= len(setarr) - 1:
            for child in node.children:
                # don't go to subtrees where current element cannot be
                if child.data_hash > self.idx_func(setarr[idx]):
                    break
                if child.data_hash == self.idx_func(setarr[idx]):
                    yield from self._itersupersets(child,
                                                   setarr,
                                                   idx + 1,
                                                   path,
                                                   mode)
                else:
                    yield from self._itersupersets(child,
                                                   setarr, idx,
                                                   path, mode)
        # no more elements to find: just traverse this subtree to get
        # all supersets
        else:
            for child in node.children:
                yield from self._itersupersets(child, setarr,
                                               idx, path, mode)
        if node.data is not None:
            path.pop()

    def supersets(self, aset, mode=None):
        """Return a list containing pairs of (keyset, value) for which keyset
           is superset of set aset.

           Parameter mode: see documentation for itersupersets().
        """
        return list(self.itersupersets(aset, mode))

    def hassubset(self, aset):
        """Return True iff there is at least one set in this SetTrieMap that
           is the (proper or not proper) subset of set aset.
        """
        return self._hassubset(self.root, list(sorted(aset, key=self.idx_func)), 0)

    def _hassubset(self, node, setarr, idx):
        """Used by hassubset()."""
        if node.flag_last:
            return True
        if idx > len(setarr) - 1:
            return False
        found = False
        try:
            c = node.children.index(SetTrieMap.Node(setarr[idx], trie=self))
            found = self._hassubset(node.children[c], setarr, idx + 1)
        except ValueError:
            pass
        if not found:
            return self._hassubset(node, setarr, idx + 1)
        else:
            return True

    def itersubsets(self, aset, mode=None):
        """Return an iterator over pairs (keyset, value) from this SetTrieMap
           for which keyset is (proper or not proper) subset of set aset.
           If mode is not None, the following values are allowed:

           mode='keys': return an iterator over only the keysets that
                        are subsets of aset is returned

           mode='values': return an iterator over only the values that
                          are associated to keysets that are subsets of aset

           If mode is neither of 'keys', 'values' or None, behavior is
           equivalent to mode=None.
        """
        path = []
        return self._itersubsets(self.root, list(sorted(aset, key=self.idx_func)),
                                 0, path, mode)

    def _itersubsets(self, node, setarr, idx, path, mode):
        """Used by itersubsets()."""
        if node.data is not None:
            path.append(node.data)
        if node.flag_last:
            if mode == 'keys':
                yield set(path)
            elif mode == 'values':
                yield node.value
            else:
                yield (set(path), node.value)
        for child in node.children:
            if idx > len(setarr) - 1:
                break
            if child.data_hash == self.idx_func(setarr[idx]):
                yield from self._itersubsets(child, setarr,
                                             idx + 1, path, mode)
            else:
                # advance in search set until we find child (or get to
                # the end, or get to an element > child)
                jdx = idx + 1
                while jdx < len(setarr) and child.data_hash >= self.idx_func(setarr[jdx]):
                    if child.data == setarr[jdx]:
                        yield from self._itersubsets(child,
                                                     setarr,
                                                     jdx, path,
                                                     mode)
                        break
                    jdx += 1
        if node.data is not None:
            path.pop()

    def subsets(self, aset, mode=None):
        """Return a list of (keyset, value) pairs from this set-trie
           for which keyset is (proper or not proper) subset of set aset.
           Parameter mode: see documentation for itersubsets().
        """
        return list(self.itersubsets(aset, mode))

    def iter(self, mode=None):
        """Returns an iterator to all (keyset, value) pairs stored in this
           SetTrieMap (using pre-order tree traversal).  The pairs are
           returned sorted to their keys, which are also sorted.  If
           mode is not None, the following values are allowed:

           mode='keys': return an iterator over only the keysets that
                        are subsets of aset

           mode='values': return an iterator over only the values that
                          are associated to keysets that are subsets
                          of aset

           If mode is neither of 'keys', 'values' or None, behavior is
           equivalent to mode=None.
        """
        path = []
        yield from SetTrieMap._iter(self.root, path, mode)

    def keys(self):
        """Alias for self.iter(mode='keys')."""
        return self.iter(mode='keys')

    def values(self):
        """Alias for self.iter(mode='values')."""
        return self.iter(mode='values')

    def items(self):
        """Alias for self.iter(mode=None)."""
        return self.iter(mode=None)

    def __iter__(self):
        """Same as self.iter(mode='keys')."""
        return self.keys()

    @staticmethod
    def _iter(node, path, mode):
        """Recursive function used by self.iter()."""
        if node.data is not None:
            path.append(node.data)
        if node.flag_last:
            if mode == 'keys':
                yield set(path)
            elif mode == 'values':
                yield node.value
            else:
                yield (set(path), node.value)
        for child in node.children:
            yield from SetTrieMap._iter(child, path, mode)
        if node.data is not None:
            path.pop()

    def aslist(self):
        """Return a list containing all the (keyset, value) pairs stored in
           this SetTrieMap.  The pairs are returned sorted to their
           keys, which are also sorted.
        """
        return list(self.iter())

    def printtree(self, tabchr=' ', tabsize=2, stream=sys.stdout):
        """Print a mirrored 90-degree rotation of the nodes in this SetTrieMap
           to stream (default: sys.stdout).  Nodes marked as flag_last
           are trailed by the '#' character.  tabchr and tabsize
           determine the indentation: at tree level n, n*tabsize
           tabchar characters will be used.  Associated values are
           printed after ': ' trailing flag_last=True nodes.
        """
        self._printtree(self.root, 0, tabchr, tabsize, stream)

    @staticmethod
    def _printtree(node, level, tabchr, tabsize, stream):
        """Used by self.printTree(), recursive preorder traverse and printing
           of trie node
        """
        print((str(node.data).rjust(len(repr(node.data)) + level * tabsize,
                                    tabchr) + (': {}'.format(repr(node.value)) if
                                               node.flag_last else '')),
              file=stream)
        for child in node.children:
            SetTrieMap._printtree(child, level + 1, tabchr, tabsize, stream)

    def __str__(self):
        """Returns str(self.aslist())."""
        return str(self.aslist())

    def __repr__(self):
        """Returns str(self.aslist())."""
        return str(self.aslist())


... and the equivalent class for mercury.dynamics

In [3]:
try:
    import settrie

except ModuleNotFoundError:
    !pip install mercury-settrie

    import settrie

In [4]:
settrie.__version__

'1.4.2'

We instance one object of each `stm` for the pure python implementation and `st` from mercury dynamics.

In [5]:
stm = SetTrieMap()

In [6]:
st = settrie.SetTrie()

## The dataset

This dataset is a folder containing lots of text documents created randomly by composing the syllables:

`['bla' , 'co', 'doe', 'fi', 'gru', 'ho', 'je', 'ko', 'le', 'mu', 'no', 'pre', 're', 'sha', 'tri', 'voe', 'wha', 'ye', 'zu']`

The original datasets where in the size of hundredths of thousands of documents, the pure python implementation could not handle those sizes and it was reduced to 50,000.

With 50,000 documents the mercury dynamics implementation gave improved performance of around x200 to x300 and x20 in RAM size for the harder queries.

This dataset is a small sample (1,500 documents) and the gains are around x100 to x150 in the hardest queries.


In [7]:
!if [ ! -d ./data ]; then wget https://github.com/BBVA/mercury-settrie/raw/master/notebooks/settrie_data.tar.gz && tar -xf settrie_data.tar.gz; fi

## Comparing loading the dataset

Once our files are stored in the `data` folder, we use a function that measures how the Python interpreter's usage of RAM is growing. We also write `dummy_assign()`, a function that parses the files doing nothing that we use to make a fair comparison of the file loading process subtracting this common time. We also have to increase the Python interpreter's recursion limit to allow loading these files in the pure python implementation from Google colab.

In [8]:
import sys
sys.setrecursionlimit(9999)

In [9]:
def get_process_memory():
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    return mem_info.rss

In [10]:
def dummy_assign(set, key):
    pass

start_time = time.time()

for i in range(1500):
    fn = 'data/%i.txt' % (i + 1)
    with open(fn) as f:
        txt = f.readlines()[0].rstrip('\n')
        
        dummy_assign(set(txt.split(' ')), fn)
        
dummy_time = time.time() - start_time
        
print("Dummy assign time is %0.3f seconds" % (dummy_time, ))

Dummy assign time is 0.578 seconds


In [11]:
start_ram = get_process_memory()
start_time = time.time()

for i in range(1500):
    fn = 'data/%i.txt' % (i + 1)
    with open(fn) as f:
        txt = f.readlines()[0].rstrip('\n')
        
        stm.assign(set(txt.split(' ')), fn)
        
print("Time - dummy_time in %0.3f seconds" % (time.time() - start_time - dummy_time, ))
print("RAM increase in %0.1fMb" % ((get_process_memory() - start_ram)/(1024*1024), ))

Time - dummy_time in 21.080 seconds
RAM increase in 712.1Mb


In [12]:
start_ram = get_process_memory()
start_time = time.time()

for i in range(1500):
    fn = 'data/%i.txt' % (i + 1)
    with open(fn) as f:
        txt = f.readlines()[0].rstrip('\n')
        
        st.insert(set(txt.split(' ')), fn)
        
print("Time - dummy_time in %0.3f seconds" % (time.time() - start_time - dummy_time, ))
print("RAM increase in %0.1fMb" % ((get_process_memory() - start_ram)/(1024*1024), ))

Time - dummy_time in 1.525 seconds
RAM increase in 45.5Mb


## Comparing finding exact (whole document) matches

In [13]:
def stm_find(query):
    start_time = time.time()
    
    print(stm.get(query))
                
    print("Done in %0.3f seconds" % (time.time() - start_time))

In [14]:
def st_find(query):
    start_time = time.time()
    
    print(st.find(query))
                
    print("Done in %0.3f seconds" % (time.time() - start_time))

In [15]:
stm_find({'doecotri', 'cogruyedoe', 'yeho', 'bla', 'shazudoe', 'je', 'bla', 'pre'})

data/712.txt
Done in 0.001 seconds


In [16]:
st_find({'doecotri', 'cogruyedoe', 'yeho', 'bla', 'shazudoe', 'je', 'bla', 'pre'})

data/712.txt
Done in 0.001 seconds


In [17]:
stm_find({'tri', 'mupre', 'le', 'gru', 'kozugru', 'reye', 'no', 'whaprerele', 'tri', 'reblazu', 'noko', 'yewhatrihobla', 'novoe', 'jeletrifivoe', 'co'})

data/777.txt
Done in 0.002 seconds


In [18]:
st_find({'tri', 'mupre', 'le', 'gru', 'kozugru', 'reye', 'no', 'whaprerele', 'tri', 'reblazu', 'noko', 'yewhatrihobla', 'novoe', 'jeletrifivoe', 'co'})

data/777.txt
Done in 0.003 seconds


## Comparing finding partial (all words in any document) matches

In [19]:
def stm_super(query, verbose = False):
    start_time = time.time()
    
    ret = []
    
    for _, key in stm.supersets(query):
        ret.append(key)
        
    ret.sort()
    
    if verbose:
        print(ret)
    else:
        print(len(ret), 'document(s) found.')
        
    print("Done in %0.3f seconds" % (time.time() - start_time))

In [20]:
def st_super(query, verbose = False):
    start_time = time.time()
    
    ret = []
    
    for key in st.supersets(query):
        ret.append(key)
        
    ret.sort()
    
    if verbose:
        print(ret)
    else:
        print(len(ret), 'document(s) found.')
        
    print("Done in %0.3f seconds" % (time.time() - start_time))

In [21]:
stm_super({'pre', 'le', 'whamu', 'tri', 'zulenoko'}, verbose = True)

['data/1013.txt']
Done in 0.238 seconds


In [22]:
st_super({'pre', 'le', 'whamu', 'tri', 'zulenoko'}, verbose = True)

['data/1013.txt']
Done in 0.015 seconds


In [23]:
stm_super({'whagru', 'yewhako', 'blavoe'}, verbose = True)

['data/1092.txt', 'data/1400.txt', 'data/1446.txt', 'data/202.txt', 'data/314.txt', 'data/475.txt', 'data/593.txt', 'data/64.txt', 'data/759.txt', 'data/764.txt', 'data/796.txt', 'data/805.txt', 'data/891.txt']
Done in 3.151 seconds


In [24]:
st_super({'whagru', 'yewhako', 'blavoe'}, verbose = True)

['data/1092.txt', 'data/1400.txt', 'data/1446.txt', 'data/202.txt', 'data/314.txt', 'data/475.txt', 'data/593.txt', 'data/64.txt', 'data/759.txt', 'data/764.txt', 'data/796.txt', 'data/805.txt', 'data/891.txt']
Done in 0.008 seconds


In [25]:
stm_super({'whagru', 'yewhako', 'blavoe'})

13 document(s) found.
Done in 0.783 seconds


In [26]:
st_super({'whagru', 'yewhako', 'blavoe'})

13 document(s) found.
Done in 0.007 seconds


In [27]:
stm_super({'wha'})

1440 document(s) found.
Done in 2.034 seconds


In [28]:
st_super({'wha'})

1440 document(s) found.
Done in 0.019 seconds


## Comparing finding subset (the whole document can be written with a set of words) matches

In [29]:
def stm_sub(query, verbose = False):
    start_time = time.time()
    
    ret = []
    
    for _, key in stm.subsets(query):
        ret.append(key)
        
    ret.sort()
    
    if verbose:
        print(ret)
    else:
        print(len(ret), 'document(s) found.')
        
    print("Done in %0.3f seconds" % (time.time() - start_time))

In [30]:
def st_sub(query, verbose = False):
    start_time = time.time()
    
    ret = []
    
    for key in st.subsets(query):
        ret.append(key)
        
    ret.sort()
    
    if verbose:
        print(ret)
    else:
        print(len(ret), 'document(s) found.')
        
    print("Done in %0.3f seconds" % (time.time() - start_time))

In [31]:
vocabulary = {'doecotri', 'cogruyedoe', 'yeho', 'bla', 'shazudoe', 'je', 'bla', 'pre'}
vocabulary.update(['tri', 'mupre', 'le', 'gru', 'kozugru', 'reye', 'no', 'whaprerele', 'tri', 'reblazu', 'noko', 'yewhatrihobla', 'novoe', 'jeletrifivoe', 'co'])

In [32]:
stm_sub(vocabulary, verbose = True)

['data/712.txt', 'data/777.txt']
Done in 0.005 seconds


In [33]:
st_sub(vocabulary, verbose = True)

['data/712.txt', 'data/777.txt']
Done in 0.000 seconds


In [34]:
vocabulary.update(['wha', 'jehovoe', 'koyegru', 'mukoyegru', 'hono', 'gru', 'fizugru', 'lenomuwha', 'kozuwhaco', 'ho', 'pre'])

In [35]:
vocabulary.update(['trire', 'co', 'whavoe', 'noblaretripre', 'no', 'yezushadoele', 'blale', 'hodoeretriwhaye', 'le', 'shadoe'])

In [36]:
vocabulary.update(['jenoye', 'zumuhotrinofi', 'ko', 'ho', 'voebla', 'jezukofibla', 'blazu', 'lehowha', 'le', 'jezublahoko', 'kofizu', 'wha'])

In [37]:
stm_sub(vocabulary, verbose = True)

['data/218.txt', 'data/274.txt', 'data/362.txt', 'data/712.txt', 'data/777.txt']
Done in 0.004 seconds


In [38]:
st_sub(vocabulary, verbose = True)

['data/218.txt', 'data/274.txt', 'data/362.txt', 'data/712.txt', 'data/777.txt']
Done in 0.000 seconds


In [39]:
for i in range(1200):
    fn = 'data/%i.txt' % (i + 1)
    with open(fn) as f:
        txt = f.readlines()[0].rstrip('\n')
        
        vocabulary.update(txt.split(' '))

In [40]:
stm_sub(vocabulary)

1200 document(s) found.
Done in 123.723 seconds


In [41]:
st_sub(vocabulary)

1200 document(s) found.
Done in 0.589 seconds
