Union-Find + backtracking

In [None]:
# from collections import List 
from typing import List

class UnionFind:
    """Disjoint-set forest over hashable items (here, words).

    Items are registered lazily: the first time ``find`` sees an item
    (directly, or through ``union``/``connected``), that item becomes its own
    singleton set.
    """

    def __init__(self):
        # Maps each registered item to its parent; a root maps to itself.
        self.parent = {}

    def find(self, word):
        """Return the root of *word*'s set, registering *word* if it is new.

        Uses full path compression: every node visited on the way up is
        re-pointed directly at the root. This is implemented iteratively
        because ``union`` does not balance trees, so parent chains can grow
        long enough to exceed Python's recursion limit.
        """
        parent = self.parent
        if word not in parent:
            parent[word] = word
            return word

        # First pass: walk up to the root.
        root = word
        while parent[root] != root:
            root = parent[root]

        # Second pass: point every node on the path directly at the root.
        while word != root:
            next_word = parent[word]
            parent[word] = root
            word = next_word
        return root

    def union(self, word1, word2):
        """Merge the sets containing *word1* and *word2*.

        When the sets differ, *word1*'s root becomes the root of the merged set.
        """
        root1 = self.find(word1)
        root2 = self.find(word2)
        if root1 != root2:
            self.parent[root2] = root1

    def connected(self, word1, word2):
        """Return True if *word1* and *word2* belong to the same set."""
        return self.find(word1) == self.find(word2)

class Solution:
    def generateSentences(self, synonyms: List[List[str]], text: str) -> List[str]:
        """Return every sentence obtainable by swapping words of *text* for synonyms.

        Synonymy is transitive: all words linked by any chain of pairs in
        *synonyms* form one group. Each word of *text* that belongs to a group
        may be replaced by any member of that group, including itself. Words
        with no synonyms are kept unchanged.

        Args:
            synonyms: pairs ``[word1, word2]`` declaring two words synonymous.
            text: the sentence to expand; it is split on whitespace.

        Returns:
            All generated sentences, sorted lexicographically.
        """
        # Union every synonym pair so transitively-linked words share a root.
        union_find = UnionFind()
        for word1, word2 in synonyms:
            union_find.union(word1, word2)

        # Group words by their root, e.g. {"happy": ["cheerful", "happy", "joy"]}.
        # Iterate over a snapshot of the keys because find() writes into ``parent``.
        groups = {}
        for word in list(union_find.parent):
            groups.setdefault(union_find.find(word), []).append(word)
        for group in groups.values():
            group.sort()

        # Map every grouped word to its shared, sorted group, e.g.
        # {"cheerful": [...], "happy": [...], "joy": [...]}.
        # The dict is deliberately named differently from the ``synonyms``
        # parameter so that parameter is not shadowed.
        group_of = {word: group for group in groups.values() for word in group}

        # Candidate words for each position of the text; a word with no
        # synonyms has only itself as a candidate.
        words = text.split()
        options = [group_of.get(word, [word]) for word in words]

        result = []
        current_sentence = []

        def backtrack(index):
            # Every position is filled: record the sentence.
            if index == len(options):
                result.append(" ".join(current_sentence))
                return
            for candidate in options[index]:
                current_sentence.append(candidate)
                backtrack(index + 1)
                current_sentence.pop()

        backtrack(0)
        return sorted(result)