In [1]:
# necessary imports
import string

# Tokenization with transducers

In this notebook, I explore a way of extracting an FST for tokenization under an assumption that partitions the input alphabet into two classes.

I assume that there are two categories of symbols within the alphabet: *static* symbols and *active* symbols. 
The static symbols correspond to themselves on the side of output, such as in `a` $\rightarrow$ `a`.
The behavior of the active symbols needs to be determined. For example, in the pair `("stop.", "stop_.")`, the symbol `.` is rewritten with a space (shown as `_`) before it: `.` $\rightarrow$ `_.`.
However, in `("mr. Bean", "mr. Bean")`, `.` stays the same and no space is introduced before it.
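
To make the split concrete, here is a minimal sketch that labels every symbol of a word as static or active. It assumes that the static set consists of letters and digits (mirroring the toy setup at the end of the notebook); the helper name `classify` is made up for this illustration.

```python
import string

# Assumed static alphabet: letters and digits.
static = set(string.ascii_letters + string.digits)

def classify(word):
    """Label each symbol of `word` as 'static' or 'active'."""
    return [(ch, "static" if ch in static else "active") for ch in word]

labels = classify("mr. Bean")
active = [ch for ch, kind in labels if kind == "active"]
print(active)  # ['.', ' ']
```

Under this assumption, only `.` and the space need their behavior learned from the data.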

I start by initializing a Prefix Tree Transducer (PTT) for the stretches of active symbols while ignoring the stretches of the static ones -- by definition, they are the same on both input and output sides. 
Every occurrence of the active symbol in the input introduces a new state.
Also, a special symbol marks the end of line, therefore helping us to differentiate between the final and non-final active symbols stretches.
At this stage, we only have PTT for the active symbols parts of the input strings: all the static stretches are ignored.

Then I fill in the state outputs. The static symbols are written as themselves, so only the active stretches need to be aligned: given a stretch of active symbols $w_{i_1}\dots w_{i_n}$ on the input side, I find the corresponding output stretch $w_{o}$ and store it as the state output of the state reached after reading $w_{i_n}$. **_Any inconsistency at this step indicates that either the data is not deterministic, or the initial assumption about the static symbols is wrong._**
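
The consistency check can be sketched as follows. This is only an illustration: the function name `collect_outputs` and the conflicting pair are hypothetical, and `<` stands for the end-of-line marker.

```python
def collect_outputs(mappings):
    """Map each active input stretch to its output; fail on a conflict."""
    table = {}
    for inp, out in mappings:
        # setdefault stores the first output seen and returns the stored one
        if table.setdefault(inp, out) != out:
            raise ValueError("inconsistent outputs for %r: %r vs %r"
                             % (inp, table[inp], out))
    return table

# The same stretch seen twice with the same output is fine.
consistent = collect_outputs([(". ", ". "), (".<", " ."), (".<", " .")])

# A hypothetical sample in which ". " is rewritten in two different ways.
try:
    collect_outputs([(". ", ". "), (". ", " . ")])
    conflict = False
except ValueError:
    conflict = True
```

A conflict of this kind means that the active stretch alone does not determine the output, so more context (or a different choice of static symbols) would be needed.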

Then I make the PTT onward, following the ONWARD-PTT algorithm from de la Higuera (2010).
This gives a PTT in which every output is written as early as possible.

The last step is to rebuild the branches of the transducer in the following manner:
   1. Collect a list of leaf nodes of the current PTT.
   2. For every leaf node, determine if it is final (ends with the EOL character) or not.
   3. Merge all final leaves together.
   4. Merge all non-final leaves with the initial state of the FST.
   5. Add reflexive loops containing static symbols on the initial state of the FST.

This gives a transducer that writes as much as possible as soon as possible. Unfortunately, the assumption about the static symbols does not hold for the tokenization task in general.

## Results

In the testing section (end of the notebook), I show the successful application of this idea on the simplified tokenization task.

However, the overall result is not satisfying -- the initial assumption is not correct, because there is no simple way (as of now) to detect whether there are *static symbols* for the tokenization task: substrings of alphabetical symbols can themselves trigger changes, such as `haven't` $\rightarrow$ `have not`, `cannot` $\rightarrow$ `can not`, etc.

**For quick results, go directly to the end of the notebook.**

## Part 0. Getting the helper functions ready
Here, I define a few small helper functions that are used in different parts of the code and are too simple to need sections of their own.
  1. *lcp* returns longest common prefix of any number of input strings; used when onwarding the PTT;
  2. *prefix* returns all prefixes of the input string; used when building the PTT;
  3. *remove_from_prefix* removes the given prefix from the given string; used when onwarding the PTT;
  4. *alphabetize* returns the list of symbols used in the input sample; required to build the FST.
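
As a quick illustration of what the first three helpers compute, the sketch below uses the standard library function `os.path.commonprefix` as a stand-in for *lcp* (it ignores the special handling of the unknown output `*`) and plain slicing for the other two. The example words are arbitrary.

```python
import os.path

words = ["tokenize", "tokens", "token"]

# longest common prefix, compared character by character
common = os.path.commonprefix(words)

# all prefixes of the common part, from the empty string up to the whole string
prefixes = [common[:i] for i in range(len(common) + 1)]

# what is left of a word once the common prefix is removed
rest = "tokens"[len(common):]

print(common)    # token
print(prefixes)  # ['', 't', 'to', 'tok', 'toke', 'token']
print(rest)      # s
```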

In [2]:
def lcp(*args):
    ''' Finds longest common prefix of unbounded number of strings. '''
    
    w = list(set(i for i in args if i != "*"))
    if not w: raise IndexError("At least one non-unknown string needs to be provided.")
    
    result = ""
    n = min([len(x) for x in w])
    for i in range(n):
        if len(set(x[i] for x in w)) == 1: result += w[0][i]
        else: break
    
    return result


def prefix(w):
    ''' Returns a list of prefixes of a word. '''
    return [w[:i] for i in range(len(w)+1)]


def remove_from_prefix(w, pref):
    ''' Removes a substring from the prefix position of another string. '''
    
    if w.startswith(pref): return w[len(pref):]
    elif w == "*": return w
    else: raise ValueError(pref + " is not a prefix of " + w)
        

def alphabetize(sample):
    """ Returns a list of symbols that are used in the input
        side of the sample.
    """
    
    symbols = []
    for i in sample:
        for j in i[0]:
            if j not in symbols:
                symbols.append(j)
    return symbols

## Part 1. Getting the template of the FST object
First, we create the FST object and define the *rewrite* method for this class, which takes a string as input and rewrites it by going through the transitions of the current FST. The FST has the following attributes:
  1. *Sigma*: all symbols that are used in the data;
  2. *Alpha*: the static symbols, they are always rewritten only by themselves;
  3. *Beta*: the active symbols, their behavior needs to be determined;
  4. *final*: symbol that indicates the end of the string;
  5. *idk*: symbol that is being used to indicate unknown state output;
  6. *Q*: list of states;
  7. *E*: list of transitions;
  8. *stout*: dictionary of state outputs, needed for onwarding the PTT.
  
**Warning:** the *final* and *idk* symbols must not appear in *Sigma*.
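
The code in this notebook stores each transition as a four-element list `[source state, input symbol, output string, target state]`, with the empty string as the initial state. The sketch below walks through a tiny hand-written transition list in that format; the list itself and the helper `step` are made up for illustration.

```python
# [source, input symbol, output, target]
E = [
    ["", "a", "a", ""],      # static symbol copied on the initial state
    ["", ".", "", "."],      # active symbol: the output is delayed
    [".", "<", " .", "<"],   # the end marker decides what is written
]

def step(E, state, symbol):
    """Return (output, next state) for the first matching transition."""
    for src, inp, out, dst in E:
        if src == state and inp == symbol:
            return out, dst
    raise KeyError((state, symbol))

state, written = "", ""
for symbol in "aa." + "<":   # the word followed by the end marker
    out, state = step(E, state, symbol)
    written += out

print(written)  # aa .
```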

In [3]:
class FST():
    ''' Generic container class for the FST-related objects.
    * Sigma: all symbols of the alphabet;
    * Alpha: static symbols;
    * Beta: active symbols;
    * final: the EOS symbol;
    * idk: the symbol for unknown state output;
    * Q: list of states;
    * E: list of transitions;
    * stout: state output dictionary (not filled until later).
    '''
    
    def __init__(self, Sigma, Alpha, final="<", idk="*"):
        """ Initializes the FST object. """
        
        self.Sigma = Sigma
        self.Alpha = Alpha
        self.Beta = [i for i in self.Sigma if i not in self.Alpha]
        self.final = final
        self.idk = idk
        self.Q = None
        self.E = None
        self.stout = None
        
        
    def rewrite(self, word):
        """ Rewrites the given word with respect to
            the learned transductions. """
 
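        # if there are no states, there is no transducer.
        if self.Q is None:
            raise ValueError("The transducer is not constructed.")
            
        current_state = ""
        write = ""
        # read the word symbol by symbol, followed by the end marker
        for s in word + self.final:
            # note: the scan does not stop at the first match, so a later
            # transition leaving the updated state on the same symbol also fires
            for tr in self.E:
                if tr[0] == current_state and tr[1] == s:
                    write += tr[2]
                    current_state = tr[3]
        return write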

## Part 2: Finding the static symbols stretches
We can find the substrings, or stretches, of active symbols based on their location with respect to the stretches of static symbols: we know that the static symbols stay the same on the output side. The first step is therefore to create a function that returns all stretches of static symbols in a word.
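
For intuition, the same idea can be written compactly with `itertools.groupby`. This is an alternative formulation, not the function used in the notebook: it assumes that every symbol outside the static set (letters and digits here) is active, and the name `static_stretches` is made up.

```python
from itertools import groupby
import string

# Assumed static alphabet: letters and digits.
static = set(string.ascii_letters + string.digits)

def static_stretches(word):
    """Return the maximal runs of static symbols in `word`, in order."""
    runs = groupby(word, key=lambda ch: ch in static)
    return ["".join(group) for is_static, group in runs if is_static]

print(static_stretches("Hello, Jon."))  # ['Hello', 'Jon']
```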

In [4]:
def find_static_stretch(T, w):
    """ Finds the static symbols stretches in the input string w. """
    
    stretches = []
    current = ""
    
    i = 0
    while i < len(w):
        if w[i] in T.Alpha:
            current += w[i]
            
        if w[i] in T.Beta or w[i] == T.final or i == len(w)-1:
            if current != "":
                stretches.append(current)
                current = ""
        i += 1
    
    return stretches

## Part 3: Find the active stretches mapping
Here, we need to understand how the stretches of active symbols are changed.
We can locate the static stretches on the output side of each pair in the data sample, and in this way determine which sequence of active symbols in the output corresponds to which input sequence.
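
The alignment can be illustrated with a regular expression that splits both sides on the static stretches. This is a simplified sketch rather than the notebook's function: it assumes letters and digits are static, it only checks that both sides contain the same number of static stretches (not that they are identical), and the name `active_pairs` is hypothetical.

```python
import re

STATIC = re.compile(r"[A-Za-z0-9]+")   # assumed static alphabet

def active_pairs(inp, out, final="<"):
    """Pair each active input stretch with the output stretch in the same slot."""
    inp_parts = STATIC.split(inp + final)   # the end marker goes on the input side only
    out_parts = STATIC.split(out)
    if len(inp_parts) != len(out_parts):
        raise ValueError("static stretches do not line up")
    # drop slots where the input has no active symbols (e.g. before the first word)
    return [(i, o) for i, o in zip(inp_parts, out_parts) if i]

pairs = active_pairs("Hello, Jon.", "Hello , Jon .")
print(pairs)  # [(', ', ' , '), ('.<', ' .')]
```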

In [5]:
def find_mappings(T, pair):
    """ Finds which stretches of active symbols in the output correspond
        to which ones in the input.
    """
    
    # determine "points of stability", or static symbol stretches
    static_stretches = find_static_stretch(T, pair[0])
    
    mod_inp = pair[0]+T.final
    mod_out = pair[1]
    
    inp_active = []
    out_active = []
    
    # use the list of static stretches as a stack
    while static_stretches:
        
        # find at which position we find the first static stretch
        index_inp = mod_inp.find(static_stretches[0])
        index_out = mod_out.find(static_stretches[0])
        
        # if on the input side the index is not 0, it means that 
        # an active stretch precedes it: record and delete it
        if index_inp != 0:
            inp_active.append(mod_inp[:index_inp])
            out_active.append(mod_out[:index_out])
            mod_inp = mod_inp[len(inp_active[-1]):]
            mod_out = mod_out[len(out_active[-1]):]
        
        # remove the static stretch to start over
        mod_inp = mod_inp[len(static_stretches[0]):]
        mod_out = mod_out[len(static_stretches[0]):]
        del static_stretches[0]
        
    # once the static stretches are exhausted, record the remaining
    # active symbols (typically at least the end marker), if any
    if len(mod_inp) != 0:
        inp_active.append(mod_inp)
        out_active.append(mod_out)
        
    assert len(inp_active) == len(out_active)
        
    return list(zip(inp_active, out_active))

## Part 4: Building the PTT template
Similarly to de la Higuera (2010), I build the PTT, but only based on the active symbol correspondences. Note that if we did not use the final symbol (`<` by default) at this step, we would lose the information about which stretches are final in the original string and which ones are not.
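
As a small worked example, the sketch below builds the states and transitions of such a tree from three active stretches. The dictionary `stout` is written by hand here (in the notebook it is collected from the sample), and transitions use the format `[parent, last symbol, output, state]`.

```python
# Active stretch -> output; "<" marks the end of the line.
stout = {", ": " , ", ".<": " .", ". ": ". "}

# Every prefix of an active input stretch becomes a state; "" is the root.
states = []
for stretch in stout:
    for i in range(len(stretch) + 1):
        if stretch[:i] not in states:
            states.append(stretch[:i])

# One transition per non-root state, with an empty output for now.
transitions = [[q[:-1], q[-1], "", q] for q in states if q]

print(states)  # ['', ',', ', ', '.', '.<', '. ']
```

The states `.<` and `. ` share the parent `.`, which is exactly where the end marker separates the final stretch from the non-final one.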

In [6]:
def build_ptt(T, S):
    """ Builds the PTT over the active symbol stretches of the sample S. """
    
    # assign the state outputs
    T.stout = {}
    for pair in S:
        corr = list(set(find_mappings(T, pair)))
        for c in corr:
            #if c[0] == "'": print(corr, "\n", pair, "\n", "$"+c[1]+"$", "\n")
            if c[0] not in T.stout:
                T.stout[c[0]] = c[1]
            elif c[0] in T.stout and T.stout[c[0]] != c[1]:
                #print(pair, "\n", "$"+c[0]+"$", "$"+T.stout[c[0]]+"$", "$"+c[1]+"$")
                raise ValueError("Inconsistent input sample.")
                
    # create a list of states
    T.Q = []
    for i in T.stout:
        pref = prefix(i)
        for p in pref:
            if p not in T.Q:
                T.Q.append(p)
                
    # add a transition with empty output into every non-initial state
    T.E = []
    for i in T.Q:
        if len(i) >= 1:
            T.E.append([i[:-1], i[-1], "", i])
            
    # add the unknown state outputs
    for s in T.Q:
        if s not in T.stout:
            T.stout[s] = T.idk
            
    return T

## Part 5. Onwarding the PTT
I follow de la Higuera (2010) in the way the current PTT is being onwarded. The algorithm takes the previously created PTT ``T`` and makes it onward by pushing every common prefix of every output (state or transitional) closer to the root.
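
The effect of onwarding at a single state can be illustrated as follows. The sketch takes the outputs on the transitions leaving one state, extracts their longest common prefix, and pushes it towards the root. `os.path.commonprefix` stands in for the *lcp* helper, and the outputs are hand-written for the state reached after reading `?`.

```python
import os.path

# Outputs on the transitions leaving the state "?" before onwarding:
# "?" + end marker is written as " ?", and "?!" as " ?!".
outgoing = {"<": " ?", "!": " ?!"}

# The common prefix can be written one step earlier, on the way into "?".
common = os.path.commonprefix(list(outgoing.values()))

# Whatever remains stays on the outgoing transitions.
remaining = {symbol: out[len(common):] for symbol, out in outgoing.items()}

print(repr(common))  # ' ?'
print(remaining)     # {'<': '', '!': '!'}
```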

In [7]:
def onward_ptt(T, q, u):
    """ Makes the PTT onward. """
    
    # proceed as deep as possible
    for tr in T.E:
        if tr[0] == q:
            T, qx, w = onward_ptt(T, tr[3], tr[1])
            if tr[2] != "*":
                tr[2] += w
                  
    # find lcp of all ways of leaving state q or stopping in it
    t = [tr[2] for tr in T.E if tr[0] == q]
    f = lcp(T.stout[q], *t)
    
    # remove from the prefix unless it's the initial state
    if f != "" and q != "":
        for tr in T.E:
            if tr[0] == q:
                tr[2] = remove_from_prefix(tr[2], f)
        T.stout[q] = remove_from_prefix(T.stout[q], f)
                
    return T, q, f

## Part 6: Redirect branches
Now that the PTT based on the active symbols is built and onward, we can make the following three changes:
  1. All non-final leaves should loop back into the initial state;
  2. All final leaves need to be unified into a single final state;
  3. Reflexive transitions for the static symbols need to be added on the initial state.
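
These three changes can be sketched on a small hand-written transition list in the format `[source, input symbol, output, target]`. The list and the two static symbols `a` and `b` are made up for illustration.

```python
FINAL = "<"

E = [
    ["", ",", " , ", ","],
    [",", " ", "", ", "],
    ["", ".", "", "."],
    [".", "<", " .", ".<"],
    [".", " ", ". ", ". "],
]

# A leaf is a state that is entered by some transition but left by none.
sources = {src for src, _, _, _ in E}
leaves = {dst for _, _, _, dst in E if dst not in sources}

for tr in E:
    if tr[3] in leaves:
        # final leaves collapse into one final state, the others go back to the root
        tr[3] = FINAL if tr[3].endswith(FINAL) else ""

# Static symbols are copied by loops on the initial state.
for s in "ab":
    E.append(["", s, s, ""])
```

After this step, reading `, ` or `. ` returns the machine to the initial state, while `.<` ends in the single final state `<`.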

In [8]:
def redirect_branches(T):
    """ Finalizes the transducer by redirecting the leaf transitions. """

    nf, f = [], []
    
    # find all leaf nodes
    for s in T.Q:
        appears = False
        for tr in T.E:
            if tr[0] == s:
                appears = True
                break
                
        # detect if the leaf state is final or not
        if appears == False:
            if s.endswith(T.final):
                f.append(s)
            else:
                nf.append(s)

    for tr in T.E:
        # redirect non-final leaves into the initial state
        if tr[3] in nf:
            nf.remove(tr[3])
            tr[3] = ""
            
        # merge all the final leaves into a single one
        elif tr[3] in f:
            f.remove(tr[3])
            T.Q.remove(tr[3])
            tr[3] = T.final
            if T.final not in T.Q:
                T.Q.append(T.final)
    
    # add reflexive transitions for the static symbols
    for s in T.Alpha:
        T.E.append(["", s, s, ""])
    
    # kill the stout: not needed anymore
    T.stout = None
                
    return T
            

## Part 7: Putting it all together
In the next cell, I put all of the functions written above together.

In [9]:
def fstok(S, Alpha):
    """ Algorithm that learns the tokenization rules. """
    
    # provide the alphabets and initialize the FST
    Sigma = alphabetize(S)
    T = FST(Sigma, Alpha)
    
    # build the active part of the PTT
    T = build_ptt(T, S)
    T = onward_ptt(T, "", "")[0]
    
    # redirect the branches of the PTT
    T = redirect_branches(T)
    
    return T

# Testing: toy data sample

In this section, I test the proposed algorithm given the toy data.

In [10]:
S = [("Hello, Jon.", "Hello , Jon ."),
     ("Because.", "Because ."),
     ("Come in, dear friends.", "Come in , dear friends ."),
     ("Apples, bananas, kiwis.", "Apples , bananas , kiwis ."),
     ("Mr. Bean came home.", "Mr. Bean came home ."),
     ("What are you talking about?", "What are you talking about ?"),
     ("What are you talking about?!", "What are you talking about ?!")
    ]

In [11]:
Alpha = list(string.ascii_lowercase+string.ascii_uppercase)
Alpha.extend([str(i) for i in range(10)])
T = fstok(S, Alpha)
#print("States:", T.Q)
#print("Transitions:", T.E)

In [12]:
test = ["Hello, dear.", "Mr. Bean is nice.", 
        "Apples, bananas, and oranges are tasty.", "So what?!"]
for i in test:
    print(i, "--->", T.rewrite(i))

Hello, dear. ---> Hello ,  dear .
Mr. Bean is nice. ---> Mr. Bean is nice .
Apples, bananas, and oranges are tasty. ---> Apples ,  bananas ,  and oranges are tasty .
So what?! ---> So what ?!
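
The doubled spaces after `,` in the output above seem to come from the way `rewrite` scans the transition list: the scan does not stop at the first match, so after the space following `,` brings the machine back to the initial state, the space loop on the initial state can fire for the same symbol. The sketch below reproduces this on a hand-written fragment of the transitions; the fragment, its ordering, and both function names are assumptions made for illustration.

```python
# [source, input symbol, output, target]; the order matters for the scan.
E = [
    ["", ",", " , ", ","],
    [",", " ", "", ""],     # non-final leaf redirected to the initial state
    ["", " ", " ", ""],     # space rewritten as itself on the initial state
    ["", "a", "a", ""],     # static symbol loop
]

def rewrite_all_matches(E, word):
    """Apply every transition that matches during one scan of the list."""
    state, out = "", ""
    for s in word:
        for src, inp, w, dst in E:
            if src == state and inp == s:
                out += w
                state = dst
    return out

def rewrite_first_match(E, word):
    """Apply only the first matching transition for each input symbol."""
    state, out = "", ""
    for s in word:
        for src, inp, w, dst in E:
            if src == state and inp == s:
                out += w
                state = dst
                break
    return out

print(repr(rewrite_all_matches(E, "a, a")))   # 'a ,  a'
print(repr(rewrite_first_match(E, "a, a")))   # 'a , a'
```

Stopping at the first match yields a single space, which suggests that the extra space is an artifact of the scan rather than of the learned transitions.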


For a large corpus, this approach needs adjustments. In the WSJ data, the assumption that the alphabetical symbols are always rewritten as themselves is not correct.

For example, in the pair (`'Andy'`, `'Andy'`), the rule is `'` $\rightarrow$ `'`. However, in (`"Andy's"`, `"Andy 's"`), it is `'` $\rightarrow$ `_'`. Therefore, to account for this situation, we need to check whether `s` follows the apostrophe: it is `s`, and not just `'`, that decides the output.

Removing `s` from _Alpha_ does not really save the situation, because the same problem then appears with other examples. In the pair (`we're`, `we 're`), it is the symbol `r` that decides the behavior of the apostrophe.

And removing `r` would not help either: there are other examples.
   * `hadn't` $\rightarrow$ `had n't`
   * `cannot` $\rightarrow$ `can not`
   * $\dots$