In [None]:
class WFST:
    """A small weighted finite-state transducer.

    Transitions are stored in ``states`` as
    ``{state: {input_symbol: [(to_state, output_symbol, weight), ...]}}``.
    The empty string ``''`` is the epsilon label: a transition whose input
    symbol is ``''`` is taken without consuming any input.  Outputs along a
    path are concatenated with ``+`` and weights are summed.
    """

    def __init__(self):
        # state -> {input_symbol: [(to_state, output_symbol, weight), ...]}
        self.states = {}
        self.start_state = None
        self.final_states = set()

    def set_start_state(self, state):
        """Make ``state`` the start state, registering it if necessary."""
        self.start_state = state
        self.add_state(state)

    def add_state(self, state):
        """Register ``state``; a state that already exists is left untouched."""
        if state not in self.states:
            self.states[state] = {}

    def add_final_state(self, state):
        """Mark ``state`` as accepting, registering it if necessary."""
        self.final_states.add(state)
        self.add_state(state)

    def add_transition(self, from_state, to_state, input_symbol, output_symbol, weight=0):
        """Add a transition ``from_state --input:output/weight--> to_state``.

        Both endpoint states are registered automatically.  Several
        transitions may share the same ``from_state`` and ``input_symbol``.
        """
        self.add_state(from_state)
        self.add_state(to_state)
        self.states[from_state].setdefault(input_symbol, []).append(
            (to_state, output_symbol, weight)
        )

    def add_epsilon_transition(self, from_state, to_state, weight=0):
        """Add a transition that consumes no input and emits no output."""
        self.add_transition(from_state, to_state, '', '', weight)

    def _epsilon_closure(self, configurations):
        """Extend ``configurations`` with everything reachable via epsilon moves.

        Args:
            configurations: iterable of ``(state, output, weight)`` tuples.

        Returns:
            A list holding every input configuration plus one configuration
            per epsilon path that starts from it.  A path never re-enters a
            state it has already visited, so epsilon cycles terminate.
        """
        closure = []
        # Depth-first walk; each stack entry carries the states on its own
        # epsilon path so that cycles are cut per path, not globally.
        stack = [
            (state, output, weight, frozenset((state,)))
            for (state, output, weight) in reversed(list(configurations))
        ]
        while stack:
            state, output, weight, on_path = stack.pop()
            closure.append((state, output, weight))
            for (next_state, output_symbol, step_weight) in reversed(
                self.states[state].get('', ())
            ):
                if next_state in on_path:
                    continue
                stack.append((
                    next_state,
                    output + output_symbol,
                    weight + step_weight,
                    on_path | {next_state},
                ))
        return closure

    def process(self, input_sequence):
        """Run the transducer over ``input_sequence``.

        Epsilon transitions are followed before the first symbol, between
        symbols and after the last symbol, and never consume input.

        Args:
            input_sequence: iterable of input symbols.  Empty-string symbols
                are skipped because ``''`` denotes epsilon.

        Returns:
            A list of ``(final_state, output, weight)`` tuples, one per
            accepting path; empty if the input is rejected.

        Raises:
            KeyError: if the start state has not been set.
        """
        current = self._epsilon_closure([(self.start_state, '', 0)])
        for symbol in input_sequence:
            if symbol == '':
                # An epsilon in the input carries no symbol to consume.
                continue
            stepped = []
            for (state, output, weight) in current:
                for (next_state, output_symbol, step_weight) in self.states[state].get(symbol, ()):
                    stepped.append((next_state, output + output_symbol, weight + step_weight))
            current = self._epsilon_closure(stepped)
        return [config for config in current if config[0] in self.final_states]

    def compose(self, other):
        """Build the product of this transducer with ``other``.

        States of the result are ``(state_of_self, state_of_other)`` pairs,
        explored breadth-first from the pair of start states.  For each
        reachable pair, transitions are combined as follows (outputs are
        concatenated, weights are summed, both machines always move):

        * both sides have a transition on the same input symbol: the result
          transition carries that symbol;
        * ``self`` moves on a symbol while ``other`` takes an epsilon
          transition: the result transition carries the symbol;
        * ``self`` takes an epsilon transition while ``other`` moves on a
          symbol: the result transition is an epsilon transition.

        Each combination is added exactly once.  Every pair of final states
        is marked final in the result.

        NOTE(review): transitions are paired on their *input* symbols; the
        textbook composition pairs the output of the first machine with the
        input of the second.  Confirm which behaviour callers expect.

        Raises:
            KeyError: if either machine has no start state set.
        """
        from collections import deque

        result = WFST()
        start = (self.start_state, other.start_state)
        result.set_start_state(start)
        queue = deque([start])
        visited = {start}

        def link(source, label, left_moves, right_moves):
            # Add the cross product of two transition lists and schedule
            # any state pair that has not been explored yet.
            for (n1, o1, w1) in left_moves:
                for (n2, o2, w2) in right_moves:
                    target = (n1, n2)
                    result.add_transition(source, target, label, o1 + o2, w1 + w2)
                    if target not in visited:
                        visited.add(target)
                        queue.append(target)

        while queue:
            pair = queue.popleft()
            left = self.states[pair[0]]
            right = other.states[pair[1]]

            # Same input symbol on both sides (covers epsilon/epsilon too).
            for symbol, moves in left.items():
                if symbol in right:
                    link(pair, symbol, moves, right[symbol])

            # self consumes a symbol, other takes an epsilon move.
            if '' in right:
                for symbol, moves in left.items():
                    if symbol != '':
                        link(pair, symbol, moves, right[''])

            # self takes an epsilon move, other consumes a symbol.
            if '' in left:
                for symbol, moves in right.items():
                    if symbol != '':
                        link(pair, '', left[''], moves)

        for f1 in self.final_states:
            for f2 in other.final_states:
                result.add_final_state((f1, f2))
        return result


In [7]:
# Demo: two single-transition transducers, one mapping the token "twenty"
# to "20" and one mapping "twenty-one" to "21".
wfst1 = WFST()
wfst1.set_start_state('q0')
wfst1.add_final_state('q1')
wfst1.add_transition('q0', 'q1', 'twenty', '20')

wfst2 = WFST()
wfst2.set_start_state('q0')
wfst2.add_final_state('q1')
wfst2.add_transition('q0', 'q1', 'twenty-one', '21')

# Run each transducer on its own one-token input.
result1 = wfst1.process(['twenty'])
result2 = wfst2.process(['twenty-one'])

# Pair every accepting path of the first run with every accepting path of
# the second: keep the second run's state, concatenate the outputs and
# add the weights.
combined_result = [
    (second_state, first_output + second_output, first_weight + second_weight)
    for (_, first_output, first_weight) in result1
    for (second_state, second_output, second_weight) in result2
]

print(combined_result)

[('q1', '2021', 0)]
