In [2]:
class Node(object):

    def __init__(self, word=None):
        self.word = word
        self.children = {}
        self.count = 0

    def __set__(self, instance, value):
        self.instance = value

    def __get__(self, instance, owner):
        return self.instance
    
    

class Trie(object):

    def __init__(self):
        self.root = Node('*')  # trie root
        self.oov = Node()      # node for oov values
        self.size = 0          # depth of trie

    def __set__(self, instance, value):
        self.instance = value

    def __get__(self, instance, owner):
        return self.instance

    def add(self, sequence):
        node = self.root
        node.count += 1  # total word count
        for word in sequence:
            node.children[word] = node.children.setdefault(word, Node(word))
            node = node.children[word]
            node.count += 1

    def get(self, sequence):
        node = self.root
        for word in sequence:
            node = node.children.get(word, self.oov)
        return node

    def traverse(self, node=None, sequence=None, size=None):
        sequence = sequence if sequence else []
        node = self.root if not node else node

        if not node.children:
            yield sequence

        if size:
            if len(sequence) == size:
                yield sequence

        for word, n in node.children.items():
            sequence.append(word)
            yield from self.traverse(n, sequence, size=size)
            sequence.pop()

    def v(self, size=None):
        return len(list(self.traverse(size=size)))

In [9]:
counts = Trie()


print("All In Trie:")
for seq in counts.traverse():
    print(seq)

counts.add(['a', 'b', 'c', 'd'])
    
print("All In Trie after:")
for seq in counts.traverse():
    print(seq, counts.get(seq).count)
    
counts.add(['a', 'e', 'c', 'f'])
for seq in counts.traverse():
    print(seq, counts.get(seq).count)

All In Trie:
[]
All In Trie after:
['a', 'b', 'c', 'd'] <built-in method count of list object at 0x7f012a528dc0>
['a', 'b', 'c', 'd'] <built-in method count of list object at 0x7f012a53bb80>
['a', 'e', 'c', 'f'] <built-in method count of list object at 0x7f012a53bb80>


In [8]:

# adding 4-grams
counts.add(['a', 'b', 'c', 'd'])
counts.add(['a', 'e', 'c', 'f'])
counts.add(['a', 'b', 'g', 'h'])
counts.add(['x', 'y', 'z', 'a'])

# setting & getting meta-info
counts.size = 4
print('ngram size:', counts.size)

# testing counts for n-grams of various sizes
tests = [['a'], ['a', 'b'], ['a', 'x'], ['e', 'c', 'f'], ['a', 'e', 'c', 'f']]

# getting counts
for seq in tests:
    print(counts.get(seq).count, seq)

# traversing trie: getting all strings
print("All In Trie:")
for seq in counts.traverse():
    print(seq)

# traversing trie: by prefix (['a'])
print("In Trie for: {}".format(['a']))
for seq in counts.traverse(node=counts.get(['a']), sequence=['a']):
    print(counts.get(seq).count, seq)
    
# getting size of ngram vocabulary
for i in range(counts.size):
    print("{}-gram V: {}".format(i+1, counts.v(size=i+1)))

ngram size: 4
5 ['a']
3 ['a', 'b']
0 ['a', 'x']
0 ['e', 'c', 'f']
2 ['a', 'e', 'c', 'f']
All In Trie:
['a', 'b', 'c', 'd']
['a', 'b', 'g', 'h']
['a', 'e', 'c', 'f']
['x', 'y', 'z', 'a']
In Trie for: ['a']
2 ['a', 'b', 'c', 'd']
1 ['a', 'b', 'g', 'h']
2 ['a', 'e', 'c', 'f']
1-gram V: 6
2-gram V: 7
3-gram V: 8
4-gram V: 8
