# OSTIA algorithm: implementation and tests
The algorithm is taken, with insignificant modifications, from Colin de la Higuera (2010).

Transductions are total functions, and for this reason a symbol \* is used to indicate an unknown output. The unknown symbol \* has the following properties:
   * Absorbent: concatenating \* with any other string results in \*;
   * Neutral: the longest common prefix (lcp) of a set of strings A together with \* is lcp(A).
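
The two properties can be illustrated with a minimal sketch. The names `UNKNOWN`, `concat` and `lcp_known` are hypothetical and are not part of the notebook; the sketch also assumes that the lcp of unknown outputs only is itself unknown, whereas the notebook's own `lcp` raises an error in that case.

```python
import os

UNKNOWN = "*"  # hypothetical constant standing in for the unknown output


def concat(u, v):
    """Concatenate two outputs; the unknown symbol absorbs everything."""
    if u == UNKNOWN or v == UNKNOWN:
        return UNKNOWN
    return u + v


def lcp_known(*outputs):
    """Longest common prefix that ignores unknown outputs."""
    known = [w for w in outputs if w != UNKNOWN]
    if not known:
        # assumption: with nothing known, the result stays unknown
        return UNKNOWN
    return os.path.commonprefix(known)


print(concat("01", UNKNOWN))               # *
print(lcp_known("0110", UNKNOWN, "0101"))  # 01
```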
   
<img src="scheme.png" width="400"/>

# Implementation of OSTIA
This section is devoted to the implementation of OSTIA.

### Transducer's template construction

The OSTIA algorithm takes an already constructed transducer as input and fills in the transitions.

The next cell builds a transducer template *T*. It has $6$ attributes: *Q* (list of states), $\Sigma$ (input alphabet), $\Gamma$ (output alphabet), *qe* (initial state), *E* (list of transitions) and $\sigma$ (state output function).

In [1]:
from copy import deepcopy


class FST():
    ''' Generic container class for the FST-related objects.
    * Q: list of states;
    * Sigma: input alphabet;
    * Gamma: output alphabet;
    * qe: initial state (usually "");
    * E: list of transitions;
    * stout: state output dictionary.
    '''
    
    def __init__(self, Sigma=None, Gamma=None):
        self.Q = None
        self.Sigma = Sigma
        self.Gamma = Gamma
        self.qe = ""
        self.E = None
        self.stout = None
        
        
    def rewrite(self, w):
        ''' Rewrites the string w with respect to the transducer. '''
        
        if self.Q is None:
            raise ValueError("The transducer needs to be constructed.")
        
        # move through the transducer and write the output
        result = ""
        current_state = self.qe
        for symbol in w:
            # reset the flag for every symbol so that a missing transition is detected
            moved = False
            for tr in self.E:
                if tr[0] == current_state and tr[1] == symbol:
                    result += tr[2]
                    current_state, moved = tr[3], True
                    break
            if not moved:
                raise ValueError("This string cannot be read by the current transducer.")
                
        # add the final state output
        if self.stout[current_state] != "*":
            result += self.stout[current_state]
            
        return result
                    
        
        
        
def copy_fst(T_orig):
    ''' We need to be able to do a deep copy of FST in order to backtrack
        efficiently when testing if one subtree can be folded into another.
    '''
    T = FST()
    T.Q = deepcopy(T_orig.Q)
    T.Sigma = deepcopy(T_orig.Sigma)
    T.Gamma = deepcopy(T_orig.Gamma)
    T.E = deepcopy(T_orig.E)
    T.stout = deepcopy(T_orig.stout)
    
    return T

### Helper functions

The cell below defines basic functions used in OSTIA such as ``prefix``, ``lcp``, etc.

In [2]:
def prefix(w):
    ''' Returns a list of prefixes of a word. '''
    
    return [w[:i] for i in range(len(w)+1)]



def lcp(*args):
    ''' Finds the longest common prefix of any number of strings, ignoring the unknown symbol. '''
    
    w = list(set(i for i in args if i != "*"))
    if not w: raise IndexError("At least one non-unknown string needs to be provided.")
    
    result = ""
    n = min([len(x) for x in w])
    for i in range(n):
        if len(set(x[i] for x in w)) == 1: result += w[0][i]
        else: break
    
    return result



def remove_from_prefix(w, pref):
    ''' Removes a substring from the prefix position of another string. '''
    
    if w.startswith(pref): return w[len(pref):]
    elif w == "*": return w
    else: raise ValueError(pref + " is not a prefix of " + w)

### BUILD-PTT

BUILD-PTT builds a Prefix Tree Transducer (PTT) based on the data sample. It instantiates the list of states (one per prefix of every input string), creates transitions with empty outputs, and sets the output of each state to the known output if the state is itself an input string of the sample, or to the unknown symbol \* otherwise.

In [3]:
def build_ptt(S, Sigma, Gamma):
    """ Builds a prefix tree transducer based on the data sample. """
    
    # build a template for the transducer
    T = FST(Sigma, Gamma)
    
    # fill in the states of the transducer
    T.Q = []
    for i in S:
        for j in prefix(i[0]):
            if j not in T.Q:
                T.Q.append(j)
                
    # fill in the empty transitions
    T.E = []
    for i in T.Q:
        if len(i) >= 1:
            T.E.append([i[:-1], i[-1], "", i])
            
    # fill in state outputs
    T.stout = {}
    for i in T.Q:
        for j in S:
            if i == j[0]:
                T.stout[i] = j[1]
        if i not in T.stout:
            T.stout[i] = "*"
    
    return T
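
To make the shape of the PTT concrete, here is a compact sketch of the same construction on a two-pair sample. The function name `ptt_parts` and the sample are assumptions made for illustration; the sketch returns plain Python objects instead of an `FST` instance.

```python
def ptt_parts(sample):
    """Return (states, transitions, state outputs) of the prefix tree."""
    # one state per prefix of every input string, in order of discovery
    states = []
    for inp, _ in sample:
        for i in range(len(inp) + 1):
            if inp[:i] not in states:
                states.append(inp[:i])

    # every non-initial state is entered by one transition with empty output
    transitions = [[q[:-1], q[-1], "", q] for q in states if q]

    # known output for sample inputs, unknown symbol otherwise
    known = dict(sample)
    outputs = {q: known.get(q, "*") for q in states}
    return states, transitions, outputs


states, transitions, outputs = ptt_parts([("ab", "01"), ("b", "1")])
print(states)       # ['', 'a', 'ab', 'b']
print(transitions)  # [['', 'a', '', 'a'], ['a', 'b', '', 'ab'], ['', 'b', '', 'b']]
print(outputs)      # {'': '*', 'a': '*', 'ab': '01', 'b': '1'}
```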

### ONWARD-PTT

Takes the previously created PTT ``T`` and makes it onward by pushing every common prefix of every output (state or transitional) closer to the root.

In [4]:
def onward_ptt(T, q, u):
    """ Makes the PTT onward. """
    
    # going to the leaves of the transducer and trying to push as much as possible
    for tr in T.E:
        if tr[0] == q:
            T, q1, w = onward_ptt(T, tr[3], tr[1])
            if tr[2] != "*":
                tr[2] = tr[2] + w
    
    # find the part of the outputs that can be pushed towards the front
    before_pushing = [T.stout[q]]
    for tr in T.E:
        if tr[0] == q:
            before_pushing.append(tr[2])
    f = lcp(*before_pushing)
    
    # pushing the lcp towards the front
    if f != "":
        for tr in T.E:
            if tr[0] == q:
                tr[2] = remove_from_prefix(tr[2], f)
        T.stout[q] = remove_from_prefix(T.stout[q], f)
        
    return T, q, f
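
The pushing step at a single state can be sketched in isolation. The helper `push_forward` is hypothetical: it takes the state output and the outputs of the outgoing transitions, and returns the common prefix that moves towards the root together with what stays behind.

```python
import os


def push_forward(state_output, transition_outputs):
    """One onward step at a single state (illustrative helper)."""
    # the unknown symbol is neutral for the lcp, so it is left out
    known = [w for w in [state_output] + transition_outputs if w != "*"]
    common = os.path.commonprefix(known) if known else ""

    def strip(w):
        # the unknown symbol stays unknown, everything else loses the prefix
        return w if w == "*" else w[len(common):]

    return common, strip(state_output), [strip(w) for w in transition_outputs]


print(push_forward("*", ["01", "011"]))  # ('01', '*', ['', '1'])
print(push_forward("0", ["01", "00"]))   # ('0', '', ['1', '0'])
```

In `onward_ptt` the returned prefix is appended to the output of the transition that enters the state, which is how outputs travel towards the root.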

### OSTIA-OUTPUTS
This function compares two strings and returns the other string if one of them is unknown (returns `apple` given `apple,*`), either of the strings if they are identical (returns `apple` given `apple,apple` or `*` given `*,*`), or `False` if they are different (returns `False` given `apple,banana`).

In [5]:
def ostia_outputs(w1,w2):
    ''' Compares two strings allowing for unknown. '''
    
    if w1 == "*": return w2
    elif w2 == "*": return w1
    elif w1 == w2: return w2
    else: return False
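
Because the function returns either a string or `False`, callers have to tell a successful empty-string result apart from a failure. The sketch below repeats the function from the cell above so that it runs on its own, and shows why a comparison with `False` is safe here while a truthiness test would not be.

```python
def ostia_outputs(w1, w2):
    ''' Same logic as the cell above, repeated to keep the example runnable. '''
    if w1 == "*": return w2
    elif w2 == "*": return w1
    elif w1 == w2: return w2
    else: return False


print(ostia_outputs("apple", "*"))       # apple
print(ostia_outputs("apple", "banana"))  # False

# two empty outputs are compatible, and the merged output is the empty string
merged = ostia_outputs("", "")
print(repr(merged))      # ''
print(merged == False)   # False: a string never equals False
print(not merged)        # True: a truthiness test would wrongly report failure
```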

### OSTIA-PUSHBACK
If a state Q2 is being folded into the state Q1, OSTIA-PUSHBACK pushes the non-common suffixes of the transition outputs of these two states further into the subtrees of Q1 and Q2. This is needed to check whether the folding can be done successfully.

In [6]:
def ostia_pushback(T_orig, q1, q2, a):
    ''' Moves further the part after lcp of two states. '''
    
    # to avoid rewriting the original transducer
    T = copy_fst(T_orig)
    
    # states where you get if follow a
    q1_goes_to = None
    q2_goes_to = None
    
    # what is being written from this state
    from_q1, from_q2 = None, None
    for tr in T.E:
        if tr[0] == q1 and tr[1] == a:
            from_q1 = tr[2]
            q1_goes_to = tr[3]
        if tr[0] == q2 and tr[1] == a:
            from_q2 = tr[2]
            q2_goes_to = tr[3]
    if from_q1 == None or from_q2 == None:
        raise ValueError("One of the states cannot be found.")
    
    # find the part after longest common prefix
    u = lcp(from_q1, from_q2)
    remains_q1 = from_q1[len(u):]
    remains_q2 = from_q2[len(u):]
    
    # assign lcp as current output
    for tr in T.E:
        if tr[0] in [q1, q2] and tr[1] == a:
            tr[2] = u
            
    # find what the next state writes given any other choice
    # and append the common part in it
    for tr in T.E:
        if tr[0] == q1_goes_to:
            tr[2] = remains_q1 + tr[2]
        if tr[0] == q2_goes_to:
            tr[2] = remains_q2 + tr[2]
    
    # append common part to the next state's state output
    T.stout[q1_goes_to] = remains_q1 + T.stout[q1_goes_to]
    T.stout[q2_goes_to] = remains_q2 + T.stout[q2_goes_to]
    
    return T
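
On the level of strings, the pushback amounts to splitting two outputs at their longest common prefix and prepending the leftovers to whatever is written next. A minimal sketch, assuming both outputs are known strings (the helper name `split_outputs` and the example values are hypothetical):

```python
import os


def split_outputs(out_q1, out_q2):
    """Return (common prefix, leftover of q1, leftover of q2)."""
    u = os.path.commonprefix([out_q1, out_q2])
    return u, out_q1[len(u):], out_q2[len(u):]


u, rest_q1, rest_q2 = split_outputs("0", "011")
print(u, repr(rest_q1), repr(rest_q2))  # 0 '' '11'

# the leftover of q2 is prepended to every output that follows q2's successor
next_outputs = ["1", "0"]
print([rest_q2 + w for w in next_outputs])  # ['111', '110']
```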

### OSTIA-MERGE
This function merges the states Q1 and Q2, and tries to fold their subtrees into each other. If it is not possible, OSTIA-MERGE backtracks the transducer to its original state and returns `False`.

In [7]:
def ostia_merge(T_orig, q1, q2):
    ''' Redirects all branches to q2 into q1. '''
    
    # to avoid rewriting the original transducer
    T = copy_fst(T_orig)
    
    # save which transition was changed to revert in case cannot merge the states
    changed = None
    for tr in T.E:
        if tr[3] == q2:
            changed = tr[:]
            tr[3] = q1
            
    # save the state output of the q1 originally
    changed_stout = T.stout[q1]
            
    # check if we can merge the states
    can_do = ostia_fold(T, q1, q2)
    
    # if cannot, revert the change
    if can_do == False:
        for tr in T.E:
            if tr[0] == changed[0] and tr[1] == changed[1] and tr[2] == changed[2]:
                tr[3] = changed[3]
        T.stout[q1] = changed_stout
        return False
    
    # if can, do it
    else: return can_do

### OSTIA-FOLD
This function recursively folds the subtrees of Q2 into Q1. If the folding is possible, the function performs it and returns the updated transducer; otherwise it returns `False`.

Folding is not possible when the outputs of Q1 and Q2 are inconsistent, which is the case when:
   * for the same input symbol, the transition output of Q1 is not a prefix of the transition output of Q2;
   * the state outputs of Q1 and Q2 are both known and different.

In [8]:
def ostia_fold(T_orig, q1, q2):
    ''' Folds recursively subtrees of Q2 into Q1. '''
    
    # to avoid rewriting the original transducer
    T = copy_fst(T_orig)
    
    # compare the state outputs
    w = ostia_outputs(T.stout[q1], T.stout[q2])
    if w == False: return False
    
    # rewrite * in case it's the output of q1
    T.stout[q1] = w

    # look at every possible subtree of q_2
    for a in T.Sigma:
        add_new = False

        for tr_2 in T.E:
            if tr_2[0] == q2 and tr_2[1] == a:
                
                # if the edge exists from q1
                edge_defined = False
                for tr_1 in T.E:
                    if tr_1[0] == q1 and tr_1[1] == a:
                        edge_defined = True
                        
                        # fail if inconsistent with output of q2
                        if tr_1[2] not in prefix(tr_2[2]):
                            return False
                        
                        # move the mismatched suffix of q1 and q2 further
                        T = ostia_pushback(T, q1, q2, a)
                        T = ostia_fold(T, tr_1[3], tr_2[3])
                        if T == False: return False
                        
                # if the edge doesn't exist from q1 yet, add it
                if not edge_defined:
                    add_new = [q1, a, tr_2[2], tr_2[3]]
        
        # if the new transition was constructed, add it to the list of transitions
        if add_new:
            T.E.append(add_new)
    
    return T

### OSTIA-CLEAN
If one just follows the OSTIA instructions, the resulting transducer contains non-reachable states because the algorithm doesn't remove them: they are simply never considered because they are not colored blue or red. The OSTIA-CLEAN function removes those states from the transitions, the state outputs, and the list of states.

In [9]:
def ostia_clean(T_orig):
    ''' Cleans the resulting transducers by getting rid of the states that were never processed
        (i.e. never colored red or blue) -- those states are not reachable from the initial state.
    '''
    
    # to avoid rewriting the original transducer
    T = copy_fst(T_orig)
    
    # determine which states are reachable, i.e. accessible from the initial state
    reachable_states = [""]
    add = []
    change_made = True
    while change_made == True:
        change_made = False
        for st in reachable_states:
            for tr in T.E:
                if tr[0] == st and tr[3] not in reachable_states and tr[3] not in add:
                    add.append(tr[3])
                    change_made = True

        # break out of the loop if after checking the list once again, no states were added
        if change_made == False:
            break
        else:
            reachable_states.extend(add)
            add = []
            
    # clean the list of transitions
    new_E = []
    for tr in T.E:
        if tr[0] in reachable_states and tr[3] in reachable_states:
            new_E.append(tr)
    T.E = new_E

    # clean the dictionary of state outputs
    new_stout = {}
    for i in T.stout:
        if i in reachable_states:
            new_stout[i] = T.stout[i]
    T.stout = new_stout

    # clean the list of states
    new_Q = [i for i in T.Q if i in reachable_states]
    T.Q = new_Q
    
    return T
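
The reachability computation can also be written as a breadth-first search over the transition list. This is only a sketch of an alternative formulation, not the notebook's code; the function name `reachable` and the example transitions are assumptions.

```python
from collections import deque


def reachable(transitions, start=""):
    """Set of states that can be reached from the start state."""
    seen = {start}
    queue = deque([start])
    while queue:
        state = queue.popleft()
        for source, _, _, target in transitions:
            if source == state and target not in seen:
                seen.add(target)
                queue.append(target)
    return seen


# the last transition starts in 'a', which nothing leads to any more
E = [["", "a", "0", ""], ["", "b", "1", ""], ["a", "b", "", "ab"]]
print(sorted(reachable(E)))  # ['']
```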

### OSTIA
OSTIA proceeds as follows:
   1. Create the prefix tree transducer from the input sample S using `build_ptt`;
   2. Make that PTT onward using `onward_ptt`;
   3. Color the initial state *red* and all the states reachable from the initial one in a single step *blue*;
   4. For every blue state, try to merge it with any red state using `ostia_merge` and fold its subtree into the red state:
      * if it is possible, do it and remove that blue state from the list of states;
      * if it is not possible, color that blue state red;
   5. Create an updated list of blue states by listing all the uncolored states that are accessible from the red ones;
   6. If the list of blue states is not empty, go to *4*;
   7. Clean the resulting transducer from the inaccessible states;
   8. Return the trimmed transducer.
   
**Arguments**:
   * S: data sample;
   * Sigma: input alphabet;
   * Gamma: output alphabet.
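
Step 5 of the procedure, recomputing the blue states, can be sketched on its own. The helper `blue_states` and the example transitions are hypothetical; unlike the loop in `ostia`, this sketch also skips duplicates.

```python
def blue_states(transitions, red):
    """States one step away from a red state that are not red themselves."""
    blue = []
    for source, _, _, target in transitions:
        if source in red and target not in red and target not in blue:
            blue.append(target)
    return blue


E = [["", "a", "", "a"], ["", "b", "1", "b"],
     ["a", "a", "0", "aa"], ["b", "b", "1", "bb"]]
print(blue_states(E, [""]))       # ['a', 'b']
print(blue_states(E, ["", "a"]))  # ['b', 'aa']
```

The order of the result follows the order of the transition list, which matters because `ostia` always processes the first blue state.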

In [10]:
def ostia(S, Sigma, Gamma):
    ''' Finds a subsequential transducer that corresponds to the input sample. '''
    
    # create a template of the onward PTT
    T = build_ptt(S, Sigma, Gamma)
    T = onward_ptt(T, "", "")[0]
    
    # color the nodes
    red = [""]
    blue = [tr[3] for tr in T.E if tr[0] == "" and len(tr[1]) == 1]
    
    # choose a blue state
    while len(blue) != 0:
        blue_state = blue[0]

        # if exists state that we can merge with, do it
        exists = False
        for red_state in red:
            
            # if you already merged that blue state with something, stop
            if exists == True: break
                
            # try to merge these two states (ostia_merge returns False on failure)
            merged = ostia_merge(T, red_state, blue_state)
            if merged:
                T = merged
                exists = True
        
        # if it is not possible, color that blue state red
        if not exists:
            red.append(blue_state)
            
        # if possible, remove the folded state from the list of states
        else:
            T.Q.remove(blue_state)
            del T.stout[blue_state]
            
        # add in blue list other states accessible from the red ones that are not red
        blue = []
        for tr in T.E:
            if tr[0] in red and tr[3] not in red:
                blue.append(tr[3])
    
    # clean the transducer from non-reachable states
    T = ostia_clean(T)
                
    return T

## Experiments with the code

#### Experiment 1.
The sample `S1` shows the transduction corresponding to $a\rightarrow 0$, $b\rightarrow 1$.

In [11]:
S1 = [("ab", "01"), ("aba", "010"), ("aaa", "000"), ("bb", "11"), ("babb", "1011"), ("bbaa", "1100"), ("aa", "00"),
     ("baab", "1001"), ("ba", "10"), ("bba", "110"), ("baa", "100"), ("bab", "101")]
Sigma = ["a", "b"]
Gamma = ["0", "1"]

In [12]:
T1 = ostia(S1, Sigma, Gamma)
print("States:", T1.Q)
print("Transitions:", T1.E)
print("State outputs:", T1.stout)

States: ['']
Transitions: [['', 'a', '0', ''], ['', 'b', '1', '']]
State outputs: {'': ''}


In [13]:
test = ["aba", "bbb", "ababa", "abbaba"]
for w in test:
    print(w, "--->", T1.rewrite(w))

aba ---> 010
bbb ---> 111
ababa ---> 01010
abbaba ---> 011010


#### Experiment 2: from Colin de la Higuera (2010)
The sample `S2` shows the transduction corresponding to $b\rightarrow 1$ and $a\rightarrow 0$ unless $a$ is final, in which case $a\rightarrow 1$.

In [14]:
S2 = [("b", "1"), ("a", "1"), ("ab", "01"), ("abb", "011"), ("bb", "11"), ("aa", "01"), 
     ("aaa", "001"), ("aabaab", "001001"), ("aab", "001"), ("aaba", "0011"), ("aabaa", "00101")]
Sigma = ["a", "b"]
Gamma = ["0", "1"]

In [15]:
T2 = ostia(S2, Sigma, Gamma)
print("States:", T2.Q)
print("Transitions:", T2.E)
print("State outputs:", T2.stout)

States: ['', 'a']
Transitions: [['', 'b', '1', ''], ['', 'a', '', 'a'], ['a', 'b', '01', ''], ['a', 'a', '0', 'a']]
State outputs: {'': '', 'a': '1'}


In [16]:
test = ["aba", "bbb", "ababa", "abbaba"]
for w in test:
    print(w, "--->", T2.rewrite(w))

aba ---> 011
bbb ---> 111
ababa ---> 01011
abbaba ---> 011011
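
A quick sanity check is to confirm that the learned transducer reproduces every pair of the training sample. The sketch below copies the transitions and state outputs printed above into plain Python objects and runs them with a small stand-alone function; `run` is a hypothetical helper that assumes no state output is unknown.

```python
def run(transitions, state_outputs, word, start=""):
    """Rewrite a word with a transducer given as plain lists and dicts."""
    state, out = start, ""
    for symbol in word:
        for source, inp, output, target in transitions:
            if source == state and inp == symbol:
                out += output
                state = target
                break
        else:
            raise ValueError("no transition for %r from state %r" % (symbol, state))
    return out + state_outputs[state]


# transitions and state outputs as printed for T2
E2 = [["", "b", "1", ""], ["", "a", "", "a"], ["a", "b", "01", ""], ["a", "a", "0", "a"]]
stout2 = {"": "", "a": "1"}

S2 = [("b", "1"), ("a", "1"), ("ab", "01"), ("abb", "011"), ("bb", "11"), ("aa", "01"),
      ("aaa", "001"), ("aabaab", "001001"), ("aab", "001"), ("aaba", "0011"), ("aabaa", "00101")]

# collect every sample pair that the transducer gets wrong
mismatches = [(x, y) for x, y in S2 if run(E2, stout2, x) != y]
print(mismatches)  # []
```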
