# WFST Tutorial
## Introduction to Weighted Finite State Transducers (WFSTs)
Weighted Finite State Transducers (WFSTs) are automata where each transition between states is labeled with an input symbol, an output symbol, and a weight (usually representing cost or probability). These structures are commonly used in speech recognition, text normalization, and other tasks that map one symbol sequence to another while scoring each possible mapping.
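
Because the `WFST` class below adds up the weights along a path and prefers the smallest total, its weights behave like costs. A common convention, assumed here rather than required by the class, is to use negative log probabilities as costs, so that summing costs corresponds to multiplying probabilities. The arc probabilities in this short sketch are made up for illustration:

```python
import math

# Hypothetical probabilities of the arcs along one path through a WFST.
arc_probs = [0.5, 0.25, 0.8]

# Probability view: the path probability is the product of the arc probabilities.
path_prob = math.prod(arc_probs)

# Cost view: each arc carries -log(p), and the path cost is the sum of the arc costs.
arc_costs = [-math.log(p) for p in arc_probs]
path_cost = sum(arc_costs)

# The views agree: exp(-cost) recovers the path probability, so lower cost means higher probability.
print(round(path_prob, 6), round(math.exp(-path_cost), 6))
```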

In this notebook, we'll explore how to construct your own WFST using the provided `WFST` class. We will go through key methods, explain their purposes, and demonstrate how to implement custom transitions and states.
Let's dive into the core aspects of building a WFST.


## WFST Class Overview
The `WFST` class enables you to create a custom finite state transducer by adding states and the weighted transitions between them. Below is a list of the key methods and how they work:

- `set_start_state(state)`: Records the start state, the state from which every path begins. It does not add the state to the WFST; that happens when a transition from it is added.
- `add_state(state)`: Adds a state if it is not already present. Calling it is optional, because `add_transition` creates missing states.
- `add_final_state(state)`: Marks the state as final, indicating that the transduction can successfully terminate when reaching this state.
- `add_transition(from_state, to_state, input_symbol, output_symbol, weight=0)`:
    - Adds a transition (or arc) from `from_state` to `to_state` that reads `input_symbol`, writes `output_symbol`, and carries an associated weight (or cost).
    - This is the core method for building relationships between states.
- `add_epsilon_transition(from_state, to_state, output_symbol, weight=0)`: Adds a transition that consumes no input (an epsilon transition, stored with the empty input symbol `''`) but still emits `output_symbol`.
- `process(input_sequence)`: Walks the states in insertion order, takes the lowest-weight arc out of each non-final state, and returns a tuple `(state, output_sequence, total_weight)`. It does not consult `input_sequence`, so it is normally reached through `output`.
- `output(wfst, input_sequence)`: Builds a WFST that matches the input sequence symbol by symbol, then calls `process` on it. This is the method used in the step-by-step example.
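
The class stores every arc in a nested dictionary, `states[from_state][input_symbol]`, which holds a list of `(to_state, output_symbol, weight)` tuples, with `''` as the epsilon input. The standalone sketch below mirrors that layout with `collections.defaultdict`; `arcs` and `add_arc` are hypothetical names, not part of the class:

```python
from collections import defaultdict

# Mirror of the WFST class's storage layout:
# arcs[from_state][input_symbol] -> list of (to_state, output_symbol, weight)
arcs = defaultdict(lambda: defaultdict(list))


def add_arc(src, dst, inp, out, weight=0.0):
    """Record one arc; an input symbol of '' marks an epsilon arc."""
    arcs[src][inp].append((dst, out, weight))


add_arc(0, 1, 'a', 'x', 1.5)
add_arc(1, 2, 'b', 'y', 1)
add_arc(2, 0, '', 'c')  # epsilon arc: consumes nothing, emits 'c'

print(arcs[0]['a'])  # [(1, 'x', 1.5)]
print(arcs[2][''])   # [(0, 'c', 0.0)]
```

Keeping a list per input symbol allows several competing arcs for the same input, which is what lets a weighted search choose the cheapest one.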


## Step-by-Step Example: Creating a Simple WFST
Let's walk through an example of how to create a WFST with a few states and transitions.



```python
# Module-level registry: input symbol -> category (title) of the first WFST that
# used it. CompositeWFST uses it to pick the WFST responsible for each symbol.
graph_input_category = {}


class WFST:
    def __init__(self, title):
        self.category = title      # category name shared by all input symbols of this WFST
        self.states = {}           # state -> {input_symbol: [(to_state, output_symbol, weight), ...]}
        self.start_state = None
        self.final_state = set()   # final (accepting) states

    def set_start_state(self, state):
        # Only records the start state; it enters self.states via add_state/add_transition.
        self.start_state = state

    def add_state(self, state):
        if state not in self.states:
            self.states[state] = {}

    def add_final_state(self, state):
        self.final_state.add(state)

    def add_transition(self, from_state, to_state, input_symbol, output_symbol, weight=0):
        # Both endpoints are created on demand.
        self.add_state(from_state)
        self.add_state(to_state)
        # Remember which WFST category first used this input symbol.
        if input_symbol not in graph_input_category:
            graph_input_category[input_symbol] = self.category
        self.states[from_state].setdefault(input_symbol, []).append((to_state, output_symbol, weight))

    def add_epsilon_transition(self, from_state, to_state, output_symbol, weight=0):
        # An epsilon arc is stored under the empty-string input symbol ''.
        self.add_transition(from_state, to_state, '', output_symbol, weight)

    def insert(self, start_state, num_transitions, output_symbol):
        # Chain `num_transitions` epsilon arcs (start_state, start_state + 1, ...),
        # each emitting `output_symbol`, and mark the last state as final.
        current_state = start_state
        for _ in range(num_transitions):
            new_state = current_state + 1
            self.add_epsilon_transition(current_state, new_state, output_symbol)
            current_state = new_state
        self.add_final_state(current_state)

    def process(self, input_sequence):
        # input_sequence is not consulted here: the method walks the states in insertion
        # order and, from every non-final state, takes the lowest-weight arc, accumulating
        # output symbols and weights. It relies on `output` having built a WFST whose arcs
        # already follow the input. Returns (next_state, output_string, total_weight).
        current_output = ''
        current_weight = 0
        next_states = []
        for state in self.states:
            if state not in self.final_state:
                next_state = 0
                output_symbol = ''
                weight = 10**10  # sentinel larger than any expected arc weight
                for arcs in self.states[state].values():
                    for (buffer_next_state, buffer_output_symbol, buffer_weight) in arcs:
                        if weight >= buffer_weight:
                            next_state = buffer_next_state
                            output_symbol = buffer_output_symbol
                            weight = buffer_weight
                next_states.append((next_state, current_output + output_symbol, current_weight + weight))
                current_output += output_symbol
                current_weight += weight

        return next_states[-1]

    def compose(self, other, input_symbols):
        # Chain this WFST's arcs with the arcs of `other` that read the next input symbol
        # (or, where none exists, one epsilon arc). Every arc leaving position i is
        # redirected to i + 1, so integer states are assumed. A list passed as
        # `input_symbols` is consumed in place (one pop per matching arc).
        result = WFST(self.category + other.category)
        result.set_start_state(0)
        i = 0
        for s1 in self.states:
            i = s1
            if s1 not in self.final_state:
                for symbol1 in self.states[s1]:
                    for (n1, o1, w1) in self.states[s1][symbol1]:
                        result.add_transition(i, i + 1, symbol1, o1, w1)
        if len(self.states) > len(other.states):
            i = i - len(other.states) + 1
        for s1 in other.states:
            if s1 not in other.final_state:
                if input_symbols == []:
                    symbol = None
                elif isinstance(input_symbols, list):
                    symbol = input_symbols[0]
                else:
                    symbol = input_symbols
                if symbol in other.states[s1]:
                    for (next_state, output_symbol, weight) in other.states[s1][symbol]:
                        result.add_transition(i, i + 1, symbol, output_symbol, weight)
                        if isinstance(input_symbols, list):
                            input_symbols.pop(0)
                elif '' in other.states[s1]:
                    # Take only the first epsilon arc.
                    for (next_state, output_symbol, weight) in other.states[s1]['']:
                        result.add_transition(i, i + 1, '', output_symbol, weight)
                        break
            i += 1

        result.add_final_state(i - 1)
        return result

    def compose_alt(self, other):
        # Concatenate all arcs of self and then all arcs of other into one linear chain:
        # every arc leaving position i is redirected to i + 1 (integer states assumed).
        result = WFST(self.category + other.category)
        result.set_start_state(0)
        i = 0
        for s1 in self.states:
            i = s1
            if s1 not in self.final_state:
                for symbol1 in self.states[s1]:
                    for (n1, o1, w1) in self.states[s1][symbol1]:
                        result.add_transition(i, i + 1, symbol1, o1, w1)
        if len(self.states) > len(other.states):
            i = i - len(other.states) + 1
        for s1 in other.states:
            if s1 not in other.final_state:
                for symbol1 in other.states[s1]:
                    for (n1, o1, w1) in other.states[s1][symbol1]:
                        result.add_transition(i, i + 1, symbol1, o1, w1)
            i += 1

        result.add_final_state(i - 1)
        return result

    def output(self, wfst, input_sequence):
        # Wrap a plain WFST under its own category so CompositeWFST can look it up.
        if not isinstance(wfst, CompositeWFST):
            composite_wfst = CompositeWFST()
            composite_wfst.add_wfst(wfst.category, wfst)
        else:
            composite_wfst = wfst

        # Build one small WFST per input symbol (wrapped in a list so multi-character
        # symbols stay whole), then chain them together.
        wfst_sequence = [composite_wfst.compose([symbol]) for symbol in input_sequence]
        composite_wfst = wfst_sequence[0]
        for next_wfst in wfst_sequence[1:]:
            composite_wfst = composite_wfst.compose_alt(next_wfst)
        print(input_sequence)  # debug echo of the input sequence
        return composite_wfst.process(input_sequence)


class CompositeWFST:
    # Holds several WFSTs keyed by category and routes each input symbol to the
    # WFST whose category graph_input_category records for that symbol.
    def __init__(self):
        self.wfsts = {}

    def add_wfst(self, key, wfst):
        self.wfsts[key] = wfst

    def compose(self, input_sequence):
        if not input_sequence:
            return []

        first = input_sequence[0]
        wfst = self.wfsts.get(graph_input_category.get(first))
        composed_wfst = WFST(graph_input_category.get(first))
        composed_wfst.set_start_state(0)
        # Keep only the arcs that read `first` (or, where none exists, one epsilon arc).
        # Assumes the states are numbered 0 .. len(wfst.states) - 1.
        for i in range(len(wfst.states) - 1):
            if first in wfst.states[i]:
                for (next_state, output_symbol, weight) in wfst.states[i][first]:
                    composed_wfst.add_transition(i, next_state, first, output_symbol, weight)
            elif '' in wfst.states[i]:
                for (next_state, output_symbol, weight) in wfst.states[i]['']:
                    composed_wfst.add_transition(i, next_state, '', output_symbol, weight)
                    break

        if not composed_wfst:
            return []

        for symbol in input_sequence[1:]:
            next_wfst = self.wfsts.get(graph_input_category.get(symbol))
            if next_wfst:
                composed_wfst = composed_wfst.compose(next_wfst, symbol)
            else:
                return []

        return composed_wfst

    def output(self, composite_wfst, input_sequence):
        # Same pipeline as WFST.output for an existing CompositeWFST, without the debug print.
        wfst_sequence = [composite_wfst.compose([symbol]) for symbol in input_sequence]
        composite_wfst = wfst_sequence[0]
        for next_wfst in wfst_sequence[1:]:
            composite_wfst = composite_wfst.compose_alt(next_wfst)
        return composite_wfst.process(input_sequence)
```

### Step 1: Define the WFST


```python
wfst = WFST('name')
```

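This creates an empty WFST with the category `'name'`. Every input symbol used in a transition of this WFST is recorded under this category in the module-level `graph_input_category` dictionary, which is how `output` later finds the WFST responsible for each symbol. Next we need to add states and transitions.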
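
The category lookup can be pictured with a much smaller standalone sketch. It is a simplification, not the notebook's code: each "transducer" here is a hypothetical dictionary with one arc per input symbol, and `register`, `dispatch`, and the `'digits'` category are illustrative names.

```python
# Stand-ins for the notebook's graph_input_category and CompositeWFST.wfsts.
symbol_category = {}  # input symbol -> category of the first transducer that used it
transducers = {}      # category -> {input_symbol: (output_symbol, weight)}


def register(category, mapping):
    """Store a transducer under `category` and claim its unclaimed input symbols."""
    transducers[category] = mapping
    for symbol in mapping:
        symbol_category.setdefault(symbol, category)  # first registration wins


def dispatch(sequence):
    """Route each symbol to its category's transducer; return (output, total_weight)."""
    output, weight = '', 0
    for symbol in sequence:
        out, w = transducers[symbol_category[symbol]][symbol]
        output += out
        weight += w
    return output, weight


register('name', {'a': ('x', 1.5), 'b': ('y', 1)})
register('digits', {'1': ('one', 0), '2': ('two', 0)})
print(dispatch(['a', '2', 'b']))  # ('xtwoy', 2.5)
```

As in `add_transition`, the first category to claim a symbol keeps it, so two transducers that share an input symbol cannot both be reached through this kind of lookup.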

### Step 2: Set the Start State

```python
wfst.set_start_state(0)
```

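Here, we set the start state to `0`. States are represented by integers; the `insert`, `compose`, and `compose_alt` methods depend on this because they number new states by adding 1. `set_start_state` only records the start state: state `0` is added to `wfst.states` when the first transition leaving it is created.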

### Step 3: Add States and Transitions
Let's add some intermediary states and transitions between them.

```python
wfst.add_state(1)
wfst.add_state(2)
wfst.add_transition(0, 1, 'a', 'x', 1.5)
wfst.add_transition(1, 2, 'b', 'y', 1)
```

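This code adds two intermediate states (`1` and `2`) and two transitions, one from `0` to `1` and one from `1` to `2`. The first transition maps input `a` to output `x` with a weight of `1.5`, and the second maps `b` to `y` with a weight of `1`. The explicit `add_state` calls are optional, because `add_transition` creates any missing state; it is also what adds state `0`.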

### Step 4: Define a Final State

```python
wfst.add_final_state(2)
```

Marking state `2` as final means that a path ending there is accepted: the input read along that path has been transduced successfully. In this implementation, `process` also never takes arcs that leave a final state.
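
To make the role of the final state concrete, here is a standalone sketch (not part of the notebook's `WFST` class) that walks a deterministic transducer one input symbol at a time and accepts only if the walk ends in a final state. The flat `(state, symbol)` dictionary layout and the function name `transduce_deterministic` are hypothetical:

```python
# Arcs of the tutorial example: (from_state, input_symbol) -> (to_state, output_symbol, weight).
ARCS = {
    (0, 'a'): (1, 'x', 1.5),
    (1, 'b'): (2, 'y', 1),
}
FINAL_STATES = {2}


def transduce_deterministic(arcs, start, finals, inputs):
    """Follow exactly one arc per input symbol.

    Returns (output, weight) if the walk ends in a final state, otherwise None.
    """
    state, output, weight = start, '', 0
    for symbol in inputs:
        if (state, symbol) not in arcs:
            return None  # no arc reads this symbol: reject
        state, out, w = arcs[(state, symbol)]
        output += out
        weight += w
    return (output, weight) if state in finals else None


print(transduce_deterministic(ARCS, 0, FINAL_STATES, ['a', 'b']))  # ('xy', 2.5)
print(transduce_deterministic(ARCS, 0, FINAL_STATES, ['a']))       # None: state 1 is not final
```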

### Step 5: Process an Input Sequence
Now let's process an input sequence to see how it works.

```python
input_sequence = ['a', 'b']
output_sequence = wfst.output(wfst, input_sequence)
print(f'Output: {output_sequence[1]}, Weight: {output_sequence[2]}')
```

Printed output:

```
['a', 'b']
Output: xy, Weight: 2.5
```


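This processes the input sequence `['a', 'b']`: `output` builds a WFST that matches the input symbol by symbol, runs `process` on it, and returns a tuple `(state, output_sequence, total_weight)`, so indices `1` and `2` hold the output sequence and the total weight (cost). The first printed line is the input echoed by `output`, and the weight `2.5` is the sum of the two arc weights, `1.5 + 1`.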
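
The `process` method does not match arcs against `input_sequence`; it relies on `output` to have built a WFST that already follows the input. For comparison, here is a standalone sketch of input-driven processing: a Dijkstra-style search over `(state, position)` pairs that returns the lowest-cost path consuming the whole input and ending in a final state, following epsilon arcs without consuming input. It assumes non-negative weights and the same nested-dictionary layout as the class; `best_path` is a hypothetical name, not a method of `WFST`.

```python
import heapq
from itertools import count


def best_path(arcs, start, finals, inputs):
    """Return (output, total_weight) of the cheapest accepting path, or None.

    arcs: state -> input_symbol -> list of (to_state, output_symbol, weight);
          the input symbol '' marks an epsilon arc. Weights must be non-negative.
    """
    tie = count()  # tie-breaker so equal-cost heap entries never compare further fields
    heap = [(0, next(tie), start, 0, '')]  # (cost, tie, state, position, output)
    settled = set()
    while heap:
        cost, _, state, pos, out = heapq.heappop(heap)
        if (state, pos) in settled:
            continue
        settled.add((state, pos))
        if pos == len(inputs) and state in finals:
            return out, cost
        for symbol, edges in arcs.get(state, {}).items():
            if symbol == '':
                new_pos = pos      # epsilon arc: consume nothing
            elif pos < len(inputs) and symbol == inputs[pos]:
                new_pos = pos + 1  # consume the next input symbol
            else:
                continue
            for to_state, out_sym, weight in edges:
                if (to_state, new_pos) not in settled:
                    heapq.heappush(heap, (cost + weight, next(tie), to_state, new_pos, out + out_sym))
    return None


arcs = {
    0: {'a': [(1, 'x', 1.5), (1, 'z', 3.0)]},  # two competing arcs read 'a'
    1: {'b': [(2, 'y', 1)]},
    2: {'': [(3, '!', 0.5)]},                  # epsilon arc into final state 3
}
print(best_path(arcs, 0, {3}, ['a', 'b']))  # ('xy!', 3.0)
print(best_path(arcs, 0, {3}, ['a']))       # None
```

With non-negative weights, the first accepting `(state, position)` pair popped from the heap has minimal cost, and the `settled` set keeps epsilon cycles from looping forever.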

You can now build your own WFST by following similar steps, adding states, transitions, and processing sequences.

## Advanced Example: Using Epsilon Transitions
Sometimes, it's necessary to have transitions that don't consume any input (epsilon transitions).
Let's add an epsilon transition to our previous example.

```python
wfst.add_epsilon_transition(2, 0, 'c')
```

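This adds a transition from state `2` back to state `0` that consumes no input symbol, emits `c`, and has the default weight `0`. Epsilon transitions are useful whenever output must be produced without matching input, as in text normalization; the `insert` method, for example, chains them to emit a symbol a given number of times.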
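
Because an epsilon arc can be taken at any time, a traversal usually needs the set of states reachable through epsilon arcs alone, known as the epsilon closure. Here is a standalone breadth-first sketch, assuming the same nested-dictionary layout as the class; `epsilon_closure`, `arcs`, and `chain` are hypothetical names:

```python
from collections import deque


def epsilon_closure(arcs, start):
    """Return every state reachable from `start` using only epsilon ('' input) arcs."""
    seen = {start}
    queue = deque([start])
    while queue:
        state = queue.popleft()
        for to_state, _out, _weight in arcs.get(state, {}).get('', []):
            if to_state not in seen:
                seen.add(to_state)
                queue.append(to_state)
    return seen


# Layout of the WFST class: state -> input_symbol -> [(to_state, output, weight)].
arcs = {
    0: {'a': [(1, 'x', 1.5)]},
    1: {'b': [(2, 'y', 1)]},
    2: {'': [(0, 'c', 0)]},  # the epsilon arc from the example
}
print(epsilon_closure(arcs, 2))  # {0, 2}
print(epsilon_closure(arcs, 1))  # {1}

# A chain like the one insert(5, 2, '-') would build: 5 -> 6 -> 7, each arc emitting '-'.
chain = {5: {'': [(6, '-', 0)]}, 6: {'': [(7, '-', 0)]}}
print(epsilon_closure(chain, 5))  # {5, 6, 7}
```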

## Building Your Own WFST
You now have the basic tools to build your own WFST. Follow these steps to customize it for different applications:
1. Define your start, intermediate, and final states.
2. Add transitions with appropriate input/output symbols and weights.
3. Use epsilon transitions if needed.
4. Call `output` to process input sequences and obtain the output sequences and their total costs.

Feel free to experiment and adapt the WFST structure for tasks like text normalization, speech processing, or sequence mapping.
