In [8]:
"""
Builds a character-level text prediction tree based on a corpus that is passed to it.
HOW TO:
1. import tree
2. call `CharTree.build_tree` on a corpus (passed as one big string). It returns a prediction tree.
3. predict the continuation of already typed text with the tree's `predict` method. Prediction is word-level, i.e.
   only the beginning of a single word should be passed to `predict`, not a phrase.
"""
import string
from typing import Dict, Any, Optional
import re

from nltk.tokenize import word_tokenize
from string import punctuation

In [9]:
class CharTree:
    """
    A node of a character-level prediction trie.

    Every non-root node represents one `step`-sized chunk of a word (a single character when step == 1; the last
    chunk of a word may be shorter than `step`). `count` is the number of corpus words whose path passes through
    the node. A `None` key in `children` points to a TreeLeaf whose `count` is the number of words ending here.
    """

    def __init__(self, data, children=None, count=0, step=1):
        if children is None:
            # a fresh dict per node, never a shared default
            children = dict()
        self.data: Any = data
        self.children: Dict[Optional[str], CharTree] = children
        self.count: int = count
        self.step: int = step

    @staticmethod
    def build_tree(corpus: str, step=1):
        """
        The main thing to build a CharTree. Pass a corpus as a string and get a chartree as a return value.

        :param corpus: the whole corpus as one string; it is split into words with nltk's `word_tokenize`.
        :param step: chunk size of each node. 1 builds a plain character trie; larger values build a trie of
            character n-grams instead of unigrams (single chars).
        :return: the root node (its `data` is "").
        """
        tree = CharTree(data="", step=step)
        # Words are lower-cased because `predict` lower-cases the typed text before matching it against the tree.
        words = (word.lower().strip(punctuation) for word in word_tokenize(corpus))
        for word in words:
            if word:  # tokens consisting only of punctuation are stripped to "" and skipped
                tree.__build_branch(word)
        return tree

    def __build_branch(self, word: str):
        """
        Recursively insert `word` below this node, one `step`-sized chunk per level.

        Each chunk node on the path has its count incremented; after the last chunk a `None` key pointing to a
        TreeLeaf marks (and counts) the end of the word.
        """
        chunk = word[:self.step]  # slicing clamps, so the final chunk may be shorter than `step`
        child = self.children.get(chunk)
        if child is None:
            child = self.children[chunk] = CharTree(data=chunk, step=self.step)
        child.count += 1
        rest = word[self.step:]
        if rest:
            child.__build_branch(rest)
        else:
            leaf = child.children.get(None)
            if leaf is None:
                leaf = child.children[None] = TreeLeaf()
            leaf.count += 1

    @staticmethod
    def _format_input_data(text: str) -> str:
        """
        Normalize typed text like corpus words are normalized in `build_tree`: drop surrounding whitespace and
        punctuation. Lower-casing is done by the caller.
        """
        return text.strip(string.whitespace + punctuation)

    def _get_matching_subtree(self, typed_text) -> Optional['CharTree']:
        """
        Follow `typed_text` down the tree in `step`-sized chunks.

        :return: the node reached after consuming the whole text (this node itself for empty text), or None when
            some chunk has no matching child, i.e. the text never started a word of the corpus.
        """
        node = self
        for start in range(0, len(typed_text), self.step):
            node = node.children.get(typed_text[start:start + self.step])
            if node is None:
                return None
        return node

    def __get_most_probable_continuation(self) -> str:
        """
        Go down the tree through the children with the highest count (the most probable ones), collecting the
        chunks they represent, until the end-of-word marker (or a node without children) wins.
        """
        pieces = []
        node = self
        while node.children:  # an empty tree has no children at all
            key, node = max(node.children.items(), key=lambda kv: kv[1].count)
            if key is None:
                break
            pieces.append(key)
        return ''.join(pieces)

    def printout(self, level=0):
        """
        A simple visualisation of any CharTree object.
        Left number represents the level (from 0 to N)
        The number next to a letter represents how often it was encountered in this position
        Example:
        0| l-5
        1|   a-5
        2|     m-5
        3|       p-3
        3|       a-2
        This makes "lama" LESS probable than "lamp", given input "lam"
        """
        for key, child in self.children.items():
            print(level, '|', '\t' * level, f"{key}-{child.count}", sep='')
            child.printout(level + 1)

    def predict(self, typed_text: str):
        """
        Complete the beginning of a single word with its most probable continuation.

        :param typed_text: the typed start of one word. Matching is case-insensitive and ignores surrounding
            whitespace/punctuation.
        :return: `typed_text` followed by the most probable rest of the word, or None when the typed text does
            not start any word of the corpus.
        """
        lowered_text = CharTree._format_input_data(typed_text.lower())
        # Consume whole `step`-sized chunks first; a shorter tail can be the prefix of a longer chunk key.
        split_at = len(lowered_text) - len(lowered_text) % self.step
        head, tail = lowered_text[:split_at], lowered_text[split_at:]
        match = self._get_matching_subtree(head)
        if match is None:
            return None
        if not tail:
            return typed_text + match.__get_most_probable_continuation()
        candidates = [(key, child) for key, child in match.children.items()
                      if key is not None and key.startswith(tail)]
        if not candidates:
            return None
        key, child = max(candidates, key=lambda kv: kv[1].count)
        # the part of the chosen chunk that was not typed yet, then the rest of the word below it
        return typed_text + key[len(tail):] + child.__get_most_probable_continuation()


class TreeLeaf(CharTree):
    """End-of-word marker stored under the `None` key; its `count` is how many words ended at the parent node."""

    def __init__(self):
        super().__init__(data=None)

In [3]:
# Load the whole corpus as one string; it is the input passed to CharTree.build_tree.
# The encoding is pinned because open()'s default is locale-dependent.
# NOTE(review): assumes the corpus file is UTF-8 — confirm.
with open("../dictation_text.txt", encoding="utf-8") as f:
    text = f.read()

In [15]:
# Build the prediction tree from the loaded corpus; as the cell's last expression it is echoed by the notebook.
prediction_tree = CharTree.build_tree(text)
prediction_tree

<__main__.CharTree at 0x7ff4c67fba90>