<a href="https://colab.research.google.com/github/anjali-rgpt/Autocomplete/blob/master/Data_Structure_Radix_Tree_with_Probabilities.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import requests
from tqdm import tqdm_notebook as tqdm
import sys, inspect
import numpy as np


In [None]:
# Function to calculate size of Tree Object
# Adapted from: https://github.com/bosswissam/pysize
def get_size(obj, seen=None):
    """Return the deep size of ``obj`` in bytes.

    Starts from ``sys.getsizeof(obj)`` and adds the sizes of everything
    reachable from it: the instance ``__dict__``, dict keys and values, items
    of any non-string iterable, and attributes named in ``__slots__``.

    Args:
        obj: The object to measure.
        seen: Set of ``id()`` values already counted.  Callers normally omit
            it; the recursion passes it along so shared or self-referential
            objects are only counted once.

    Returns:
        The accumulated size in bytes, or 0 if ``obj`` was already counted.
    """
    total = sys.getsizeof(obj)
    if seen is None:
        seen = set()

    marker = id(obj)
    if marker in seen:
        return 0
    # Register the object before descending so reference cycles terminate.
    seen.add(marker)

    if hasattr(obj, '__dict__'):
        # Only the first class in the MRO that declares ``__dict__`` matters.
        owner = next(
            (klass for klass in obj.__class__.__mro__ if '__dict__' in klass.__dict__),
            None,
        )
        if owner is not None:
            descriptor = owner.__dict__['__dict__']
            if inspect.isgetsetdescriptor(descriptor) or inspect.ismemberdescriptor(descriptor):
                total += get_size(obj.__dict__, seen)

    if isinstance(obj, dict):
        for value in obj.values():
            total += get_size(value, seen)
        for key in obj.keys():
            total += get_size(key, seen)
    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        for item in obj:
            total += get_size(item, seen)

    # A class may define __slots__ and still carry a __dict__, so this check
    # is independent of the ones above.
    if hasattr(obj, '__slots__'):
        for slot in obj.__slots__:
            if hasattr(obj, slot):
                total += get_size(getattr(obj, slot), seen)

    return total

In [None]:

# Function to make a tree, when given text file
def makeTree(text):
  """Build a Radix tree from ``text`` using the recursive ``addWord`` insert.

  ``text`` is split on whitespace and every token is inserted in order while
  a progress bar advances by one per token.  The printed total is the number
  of tokens processed, duplicates included.

  Returns the root ``Radix`` node.
  """
  tokens = text.split()
  root = Radix()
  with tqdm(total=len(tokens)) as progress:
    for token in tokens:
      root.addWord(token)
      progress.update(1)
  print('Tree built with {} words'.format(len(tokens)))
  return root



In [None]:
# Same as makeTree, but uses iterative method
def fastTree(text):
  """Build a Radix tree from ``text`` using the iterative ``fastAdd`` insert.

  ``text`` is split on whitespace and every token is inserted in order while
  a progress bar advances by one per token.  The printed total is the number
  of tokens processed, duplicates included.

  Returns the root ``Radix`` node.
  """
  tokens = text.split()
  root = Radix()
  with tqdm(total=len(tokens)) as progress:
    for token in tokens:
      root.fastAdd(token)
      progress.update(1)
  print('Tree built with {} words'.format(len(tokens)))
  return root



In [None]:
# Given a text corpus, traverse the tree once per word so that each node's
# visit count records how often it lies on the path of a corpus word
def train(tree, corpus):
  """Walk ``tree`` once for every whitespace-separated token of ``corpus``.

  Only ``tree.traverse`` is called; its return value is discarded, so the
  useful effect is whatever bookkeeping ``traverse`` performs on the tree.
  """
  tokens = corpus.split()
  for token in tokens:
    tree.traverse(token)

In [None]:

# Text corpora to build tree:
# Every download gets a timeout so a stalled connection cannot hang the
# notebook, and raise_for_status() turns an HTTP error into an exception
# instead of letting an error page silently become the corpus.
s = requests.get('https://raw.githubusercontent.com/dwyl/english-words/master/words_alpha.txt', timeout=30)
s.raise_for_status()
r = requests.get('https://raw.githubusercontent.com/first20hours/google-10000-english/master/google-10000-english.txt', timeout=30)
r.raise_for_status()
g = requests.get('https://raw.githubusercontent.com/first20hours/google-10000-english/master/20k.txt', timeout=30)
g.raise_for_status()


# Corpora from largest to smallest; smoltext is a hand-written toy example.
hugetext = s.text
gianttext = g.text
bigtext = r.text
smoltext = "hello hell helmet henry helm"

In [None]:
# Radix class implementation


class Radix:
  """Node of a radix (compressed prefix) tree that also tracks visit counts.

  Attributes (declared in ``__slots__``, so instances carry no ``__dict__``):
    data:    dict mapping an edge label (a string fragment) to the child
             ``Radix`` node reached by consuming that fragment.
    endword: list of the complete words that end exactly at this node.
    count:   visit counter; starts at 1 and is incremented by ``traverse``.
  """

  __slots__ = ("data", "endword", "count")

  def __init__(self):
    self.data = dict()
    self.endword = list()
    self.count = 1

  def addWord(self, word, origin = ""):
    """Insert ``word`` into the tree recursively.

    Args:
      word:   the part of the word that still has to be placed below this node.
      origin: the complete word being inserted.  External callers leave the
              default; recursive calls pass it down so it can be recorded in
              ``endword`` once ``word`` has been fully consumed.
    """
    node = self

    # Nothing left to place: the original word ends at this node.
    if len(word) == 0 and origin != "" and origin not in node.endword:
      node.endword.append(origin)
      return

    if origin == "":
      origin += word

    # Candidate prefixes of the remaining word, longest first.
    substrings = [word[:i] for i in range(len(word), 0, -1)]

    for sub in substrings:
      if sub in node.data:
        # An existing edge matches this prefix exactly: descend.
        node = node.data[sub]
        return node.addWord(word[len(sub):], origin)
      for el in node.data:
        if el.startswith(sub):
          # The prefix ends in the middle of edge ``el``: split that edge by
          # inserting a new node for ``sub`` and hanging the old child below
          # it under the rest of the label.
          node.data[sub] = Radix()
          node.data[sub].data[el[len(sub):]] = node.data[el]
          del node.data[el]
          node = node.data[sub]
          return node.addWord(word[len(sub):], origin)
      if sub == substrings[-1]:
        # Not even the first character matched: add a fresh leaf that carries
        # the whole remaining fragment.
        node.data[word] = Radix()
        if origin not in node.data[word].endword:
          node.data[word].endword.append(origin)

  def fastAdd(self, word):
    """Insert ``word`` into the tree iteratively.

    ``word[i:j]`` is the candidate fragment: ``i`` marks how much of the word
    has been consumed and ``j`` shrinks from the end until the fragment either
    matches an edge, is a prefix of an edge (which is then split), or is down
    to a single unmatched character (a new leaf is created).

    The word is recorded in ``endword`` of the node where it ends, and only
    there.
    """
    node = self
    maxlen = len(word)
    i, j = 0, maxlen
    while i < maxlen:
      sub = word[i:j]
      if sub in node.data:
        # Exact edge match: descend and restart with the full remainder.
        node = node.data[sub]
        i, j = j, maxlen
        continue

      split = False
      for el in node.data:
        if el.startswith(sub):
          # ``sub`` ends in the middle of edge ``el``: split that edge.
          branch = Radix()
          branch.data[el[len(sub):]] = node.data[el]
          del node.data[el]
          node.data[sub] = branch
          node = branch
          i, j = j, maxlen
          split = True
          break
      if split:
        continue

      if j > i + 1:
        # Try a shorter fragment.
        j -= 1
      else:
        # Even the single next character is unknown here: the rest of the
        # word becomes a new leaf.  Move onto that leaf so the word is marked
        # on it below, not on its parent.
        leaf = Radix()
        node.data[word[i:]] = leaf
        node = leaf
        break

    if word not in node.endword:
      node.endword.append(word)

  def lookup(self, word, origin=""):
    """Recursive membership test.

    Returns True if the word is in the vocabulary, otherwise None (not
    False).  ``origin`` carries the complete word through the recursion;
    external callers leave the default.
    """
    if origin == "":
      origin += word
    if len(word) == 0 and origin in self.endword:
      return True
    node = self
    # Try the longest prefix of the remaining word first.
    substrings = [word[:i] for i in range(len(word), 0, -1)]
    for sub in substrings:
      if sub in node.data:
        node = node.data[sub]
        return node.lookup(word[len(sub):], origin)
    return None

  def fastlookup(self, word):
    """Iterative membership test; returns True or False."""
    node = self
    i, j = 0, len(word)
    while j > i:
      if word[i:j] in node.data:
        # Fragment matched an edge: descend and restart with the remainder.
        node = node.data[word[i:j]]
        i = j
        j = len(word)
      else:
        j -= 1
    # The word is present only if all of it was consumed and the node reached
    # lists it as a complete word.
    if i == len(word) and word in node.endword:
      return True
    else:
      return False

  def traverse(self, word):
    """Follow ``word`` down the tree as far as it matches, counting visits.

    ``count`` is incremented on this node and on every node entered.

    Returns:
      A ``(node, matched)`` tuple: the deepest node reached and the prefix of
      ``word`` that was consumed to get there.
    """
    node = self
    node.count += 1
    maxlen = len(word)
    i, j = 0, maxlen
    temp = ''
    while i < maxlen and j > i:
      sub = word[i:j]
      try:
        node = node.data[sub]
        node.count += 1
        temp += sub
        i, j = j, maxlen
      except KeyError:
        # No edge with this label: retry with a shorter fragment.
        j -= 1
    return node, temp

  def display(self):
    """Print the edge labels of every non-leaf node, depth first."""
    node = self
    if len(node.data) == 0:
      return
    print('Data at node:', list(node.data.keys()))
    for el in node.data:
      node.data[el].display()
    return

  def __getitem__(self, word):
    """Return the child stored under edge label ``word`` (``tree[label]``)."""
    return self.data[word]


In [None]:
def sample(preds, temperature=1.0):
    """Draw one index at random from a vector of (unnormalised) probabilities.

    The weights are raised to the power ``1 / temperature`` (done in log
    space), renormalised to sum to 1, and a single multinomial trial is drawn
    from the result.

    Args:
        preds: sequence of non-negative weights.
        temperature: values below 1 sharpen the distribution, values above 1
            flatten it.

    Returns:
        The index that was drawn, as returned by ``np.argmax``.
    """
    weights = np.asarray(preds).astype('float64')
    scaled_log = np.log(weights) / temperature
    unnormalised = np.exp(scaled_log)
    distribution = unnormalised / np.sum(unnormalised)
    draw = np.random.multinomial(1, distribution, 1)
    return np.argmax(draw)



In [None]:

def predict(tree, word):
  """Complete ``word`` by sampling a path through ``tree``.

  The tree is first traversed with ``word`` (note that ``Radix.traverse``
  increments visit counts, so predicting also updates the tree).  From the
  node reached, children are picked at random with probability proportional
  to their visit counts until a node that ends the predicted word is reached,
  no child is left, or every candidate's scaled probability is 0.01 or less.

  If part of ``word`` could not be matched, the first step is restricted to
  the children whose label starts with that unmatched remainder (when there
  are any).

  Returns the predicted word; ``word`` itself if the node reached has no
  children.
  """
  node, traversed = tree.traverse(word)
  predicted = traversed
  remainder = word[len(traversed):]

  children = node.data
  total = sum(child.count for child in children.values())
  paths = tuple(children)
  prob = tuple(child.count / total for child in children.values())

  if remainder:
    matching = tuple(label for label in paths if label.startswith(remainder))
    if matching:
      paths = matching
      prob = tuple(children[label].count / total for label in children if label in paths)

  if not children:
    return word
  # Later steps are scaled by the fan-out of the starting node.
  scale = 1 / len(children)

  while prob and max(prob) > 0.01 and predicted not in node.endword:
    step = paths[sample(prob)]
    node = node.data[step]
    predicted += step
    branches = node.data
    total = sum(branch.count for branch in branches.values())
    paths = tuple(branches)
    prob = tuple(scale * branch.count / total for branch in branches.values())
  return predicted



In [None]:

def nodecount(node):
  """Return the number of nodes below ``node`` (``node`` itself is not counted)."""
  return sum(1 + nodecount(child) for child in node.data.values())


In [None]:
corpora = smoltext  # pick the corpus: smoltext, bigtext, gianttext or hugetext

In [None]:
# rtree = makeTree(corpora)
# Build the tree with the iterative insert; the recursive build above is
# left disabled.
ftree = fastTree(corpora)

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))


Tree built with 5 words


In [None]:
# Test tree methods here

# Test the existence of all the words in the corpus
def testlook(tree, corpus):
  """Check that ``tree.fastlookup`` finds every whitespace-separated word of ``corpus``.

  Prints the first word that is not found and stops, or 'Success!' when all
  words are found.  Always returns None.
  """
  for token in corpus.split():
    if tree.fastlookup(token):
      continue
    print('Test failed :(')
    print('Couldn\'t find word:', token)
    break
  else:
    # Loop finished without hitting a missing word.
    print('Success!')

# testlook(rtree, corpora)
# Check that every word of the chosen corpus is found in the iteratively
# built tree.
testlook(ftree, corpora)




# train over some text corpus
# train(ftree,corpora)

# ftree.display()

Success!
Success!


In [None]:
def testraverse(tree, corpora):
  """Check that every whitespace-separated word of ``corpora`` can be fully traversed.

  ``tree.traverse(word)`` returns a ``(node, matched_prefix)`` tuple.  A
  non-empty tuple is always truthy, so testing the return value itself can
  never fail; the traversal is complete only when the matched prefix equals
  the whole word.

  Prints the first word that cannot be traversed and returns False, or prints
  'Success!' and returns None.
  """
  for word in corpora.split():
    result = tree.traverse(word)
    # A falsy result is still treated as a failure, as before.
    if not result or result[1] != word:
      print('Test failed :(')
      print('Couldn\'t traverse:', word)
      return False
  print('Success!')
  return

# testraverse(rtree, corpora)
# Traverse the iteratively built tree with every word of the corpus; note
# that traverse() also increments the visit counts along each path.
testraverse(ftree, corpora)

In [None]:
# Analyse size of data

# Baseline: sum of the shallow sizes (sys.getsizeof) of every
# whitespace-separated token of the corpus, duplicates included.
# NOTE: this rebinds the module-level name `s`, which earlier held the
# words_alpha.txt response.
s = 0
for word in corpora.split():
  s += sys.getsizeof(word)

print('Text corpora size:', s/1e+6, 'megabytes\n')

def analyse(tree):
  """Print the deep size of ``tree`` and how much larger it is than the raw corpus.

  The comparison baseline is the module-level ``s`` (total size of the corpus
  tokens in bytes).  Returns the tree size in bytes.
  """
  tree_bytes = get_size(tree)
  growth_bytes = tree_bytes - s
  print('Tree:', tree_bytes/1e+6, 'megabytes')
  print('Increase in size:', growth_bytes/1e+6, 'megabytes\n')
  return tree_bytes

# rsize = analyse(rtree)
fsize = analyse(ftree)  # deep size, in bytes, of the iteratively built tree
# print('Difference:', fsize-rsize, 'bytes')

Text corpora size: 0.000269 megabytes

Tree: 0.003909 megabytes
Increase in size: 0.00364 megabytes



In [None]:
# Number of nodes below the root (the root itself is not counted).
nodecount(ftree)

7

In [None]:

print(smoltext)
# Complete the prefix 'horro'. predict() goes through traverse(), so this
# call also increments visit counts in the tree.
print('Predicted:', predict(ftree,'horro'))
# t = ftree.traverse('horro')
# print('horro'[len(t[1]):])

hello hell helmet henry helm
Predicted: horror


In [None]:
def score(tree, corpora):
  """Measure how often ``predict`` completes a prefix to the original word.

  For every whitespace-separated word of ``corpora`` and every proper,
  non-empty prefix of it (lengths 1 to ``len(word) - 1``), ``predict`` is run
  on ``tree`` and the result is compared with the full word.

  Prints the number of prefixes tried, the number completed correctly and the
  accuracy in percent.  Returns None.
  """
  ctr = 0
  total = 0
  words = corpora.split()
  with tqdm(total=len(words)) as pbar:
    for word in words:
      pbar.update(1)
      for end in range(1, len(word)):
        # Score the tree that was passed in, not a module-level one.
        p = predict(tree, word[:end])
        total += 1
        if word == p:
          ctr += 1

  print('Total:', total)
  print('Correct:', ctr)
  if total:
    print('Accuracy:', 100 * ctr/total)
  else:
    # Corpus was empty or contained only one-character words.
    print('Accuracy:', 'n/a (no prefixes to score)')

In [None]:
import nltk
# Called without arguments this opens NLTK's interactive downloader prompt.
nltk.download()

# At the prompt choose "d" (Download), then type "book" as the identifier
# (presumably the data collection that `nltk.book` needs; confirm against the
# NLTK documentation).

NLTK Downloader
---------------------------------------------------------------------------
    d) Download   l) List    u) Update   c) Config   h) Help   q) Quit
---------------------------------------------------------------------------
Downloader> q


True

In [None]:
# The star import provides text1 ... text9 used below.
from nltk.book import *
texts = [text1, text2, text3, text4, text5, text6, text7, text8, text9]

# Start from a fresh tree; this replaces the earlier `ftree`.
ftree = Radix()

print('Initialising tree...')
# for text in text1:
# Insert each distinct token of text1 once (set() removes duplicates).
for word in set(text1):
  ftree.fastAdd(word)

# Baseline accuracy before any visit counts have been accumulated.
# score() runs predict() for every proper prefix of every token, so this is
# slow on a full text.
print('Pre-training scores:')
score(ftree, ' '.join(text1))

# Accumulate visit counts by traversing the tree with every token of every text.
print('\nTraining data...')
with tqdm(total=len(texts)) as pbar:
  for text in texts:
    pbar.update(1)
    train(ftree, ' '.join(text))

print('Post-training scores:')
score(ftree, ' '.join(text1))

Initialising tree...
Pre-training scores:


HBox(children=(IntProgress(value=0, max=260819), HTML(value='')))




KeyboardInterrupt: ignored

In [None]:
# Adapted from Justin Peel's implementation:
# https://stackoverflow.com/a/2412468/8986124

class Patricia():
    """Dictionary-backed Patricia (compressed) trie of strings.

    ``_data`` maps a single character to a two-item list ``[rest, children]``:
    ``rest`` is the remainder of the edge label after that character and
    ``children`` is a dict of the same shape.  An entry stored under the
    empty-string key ``''`` marks that a word ends at a node that also has
    children; a node whose ``children`` dict is empty ends a word implicitly.
    """
    def __init__(self):
        self._data = {}

    def addWord(self, word):
        """Insert ``word`` into the trie; returns None."""
        data = self._data
        i = 0
        while 1:
            try:
                # word[i:i+1] is '' once the word is exhausted, so the same
                # lookup also probes for the end-of-word marker.
                node = data[word[i:i+1]]
            except KeyError:
                if data:
                    # Node already has other children: add a new branch (or
                    # the '' marker when the word is exhausted).
                    data[word[i:i+1]] = [word[i+1:],{}]
                else:
                    # `data` is empty: either the empty root or the children
                    # dict of a leaf.
                    if word[i:i+1] == '':
                        # Word ends exactly at an existing leaf: nothing to add.
                        return
                    else:
                        if i != 0:
                            # The leaf gains a child, so the word it used to
                            # end implicitly needs an explicit marker.
                            data[''] = ['',{}]
                        data[word[i:i+1]] = [word[i+1:],{}]
                return

            i += 1
            if word.startswith(node[0],i):
                # The rest of the edge label matches the word at position i.
                if len(word[i:]) == len(node[0]):
                    # The word ends exactly at this node.
                    if node[1]:
                        try:
                            node[1]['']
                        except KeyError:
                            # Node has children but no marker yet: add one.
                            data = node[1]
                            data[''] = ['',{}]
                    return
                else:
                    # Consume the edge label and descend.
                    i += len(node[0])
                    data = node[1]
            else:
                # The word diverges inside the edge label: find the length j
                # of the common part and split the edge there.
                ii = i
                j = 0
                while ii != len(word) and j != len(node[0]) and \
                      word[ii:ii+1] == node[0][j:j+1]:
                    ii += 1
                    j += 1
                tmpdata = {}
                # Tail of the old edge keeps the old children ...
                tmpdata[node[0][j:j+1]] = [node[0][j+1:],node[1]]
                # ... and the tail of the new word becomes a sibling (keyed by
                # '' when the word ends at the split point).
                tmpdata[word[ii:ii+1]] = [word[ii+1:],{}]
                data[word[i-1:i]] = [node[0][:j],tmpdata]
                return

    def isWord(self,word):
        """Return True if ``word`` was stored as a complete word, else False."""
        data = self._data
        i = 0
        while 1:
            try:
                node = data[word[i:i+1]]
            except KeyError:
                return False
            i += 1
            if word.startswith(node[0],i):
                if len(word[i:]) == len(node[0]):
                    # Path matched to the end of the word: it is a word if the
                    # node is a leaf or carries the '' marker.
                    if node[1]:
                        try:
                            node[1]['']
                        except KeyError:
                            return False
                    return True
                else:
                    i += len(node[0])
                    data = node[1]
            else:
                return False

    def isPrefix(self,word):
        """Return True if ``word`` matches the start of a stored path, else False."""
        data = self._data
        i = 0
        wordlen = len(word)
        while 1:
            try:
                node = data[word[i:i+1]]
            except KeyError:
                return False
            i += 1
            # Compare only as much of the edge label as the word still covers.
            if word.startswith(node[0][:wordlen-i],i):
                if wordlen - i > len(node[0]):
                    # Word continues past this edge: descend.
                    i += len(node[0])
                    data = node[1]
                else:
                    return True
            else:
                return False

    def removeWord(self,word):
        """Remove ``word`` from the trie.

        Prints "Word is not in trie." when the word is not found.  Only the
        word's own marker or leaf entry is removed; nothing in this method
        merges the remaining edges afterwards.
        """
        data = self._data
        i = 0
        while 1:
            try:
                node = data[word[i:i+1]]
            except KeyError:
                print ("Word is not in trie.")
                return
            i += 1
            if word.startswith(node[0],i):
                if len(word[i:]) == len(node[0]):
                    if node[1]:
                        # Node has children: the word is present only if the
                        # '' marker exists; drop just that marker.
                        try:
                            node[1]['']
                            node[1].pop('')
                        except KeyError:
                            print ("Word is not in trie.")
                        return
                    # Leaf: remove its entry from the parent dict.
                    data.pop(word[i-1:i])
                    return
                else:
                    i += len(node[0])
                    data = node[1]
            else:
                print ("Word is not in trie.")
                return

    # trie[word] is shorthand for trie.isWord(word)
    __getitem__ = isWord

In [None]:
def makePatricia(text):
  """Build a Patricia trie from the whitespace-separated tokens of ``text``.

  Every token is inserted in order while a progress bar advances by one per
  token.  The printed total is the number of tokens processed, duplicates
  included.  Returns the ``Patricia`` instance.
  """
  tokens = text.split()
  trie = Patricia()
  with tqdm(total=len(tokens)) as progress:
    for token in tokens:
      trie.addWord(token)
      progress.update(1)
  print('Tree built with {} words'.format(len(tokens)))
  return trie

# Build the Patricia trie from the same corpus, for comparison with the Radix tree.
x = makePatricia(corpora)
psize = get_size(x)/1e+6  # deep size of the trie, converted from bytes to megabytes
print(psize, 'megabytes')
x._data  # bare expression: shows the raw nested dict when run as a notebook cell