In [134]:
import base64
from IPython.display import Image, display
import re

# Sample relational-algebra queries fed to parse_query().
# Grammar used below: OPERATOR [ parameters ] ( input ), with the infix keyword JOIN between relations.
TESTCASE1 = "SELECT [ ENAME = 'Mary' & DNAME = 'Research' ] ( EMPLOYEE JOIN DEPARTMENT )"
TESTCASE2 = "PROJECTION [ BDATE ] ( SELECT [ ENAME = 'John' & DNAME = ' Research' ] ( EMPLOYEE JOIN DEPARTMENT) )"
TESTCASE3 = "SELECT [ ESSN = '01' ] ( PROJECTION [ ESSN, PNAME ] ( WORKS_ON JOIN PROJECT ) )"

# Schema catalogue: relation name -> list of the column names it exposes.
RELATION_DEF = dict(
    EMPLOYEE=["ENAME", "BDATE"],
    DEPARTMENT=["DNAME"],
    PROJECT=["PNAME"],
    WORKS_ON=["ESSN"],
)

In [157]:
def mm(graph):
    '''Render a Mermaid graph definition inline via the mermaid.ink image service.

    Args:
        graph: Mermaid source text (e.g. "graph BT\\nA-->B").

    The source is UTF-8 encoded so labels containing non-ASCII characters do
    not raise UnicodeEncodeError, and it is base64-encoded with the URL-safe
    alphabet because the result is embedded in a URL path, where the "/" and
    "+" of standard base64 would corrupt the request.
    '''
    graph_bytes = graph.encode("utf-8")
    encoded = base64.urlsafe_b64encode(graph_bytes).decode("ascii")
    display(Image(url="https://mermaid.ink/img/" + encoded))

class Node:
    '''One operator (or base relation) in a relational-algebra query tree.

    Attributes:
        node_type: "RELATION", "SELECT", "PROJECTION", "JOIN", or any other tag
            (the notebook uses "ROOT" as a sentinel above the real query).
        params: operator argument text (a selection condition or a
            comma-separated projection list); None for relations and joins.
        children: input nodes, in operand order.
        parent: the node that consumes this node's output, or None.
        name: the relation name for RELATION nodes, otherwise a unique label.
        colunms: column names available at this node, filled in by
            update_colunms(); None until then. The attribute keeps its
            historical spelling because callers read it by that name.
    '''

    # Label handed to the next auto-named node. It advances A, B, ..., Z, AA,
    # AB, ... so labels stay alphabetic (and valid Mermaid ids) no matter how
    # many nodes the kernel has created; assign Node.name = "A" to restart.
    name = "A"

    @staticmethod
    def _next_name(current):
        '''Return the label following `current` in the sequence A..Z, AA, AB, ...'''
        letters = list(current)
        pos = len(letters) - 1
        while pos >= 0:
            if letters[pos] != "Z":
                letters[pos] = chr(ord(letters[pos]) + 1)
                return "".join(letters)
            # Roll this position over and carry into the next one to the left.
            letters[pos] = "A"
            pos -= 1
        return "A" + "".join(letters)

    def __init__(self, node_type, params, children, name=None, parent=None):
        '''Create a node and adopt `children` (their parent becomes this node).

        Args:
            node_type: operator tag, see the class docstring.
            params: operator argument text or None.
            children: list of input nodes; None or empty means no inputs.
            name: explicit label; when falsy the next automatic label is used.
            parent: consuming node, if already known.
        '''
        self.node_type = node_type
        self.params = params
        self.children = children if children else []
        self.parent = parent
        for child in self.children:
            child.parent = self
        if name:
            self.name = name
        else:
            self.name = Node.name
            Node.name = Node._next_name(Node.name)
        self.colunms = None

    def update_colunms(self, relation_def=None):
        '''Recalculate the available columns of this node and all descendants.

        Args:
            relation_def: mapping of relation name -> list of column names.
                Defaults to the module-level RELATION_DEF.

        Raises:
            KeyError: if a RELATION node's name is missing from the mapping.
        '''
        if relation_def is None:
            relation_def = RELATION_DEF
        # Children first: every operator derives its columns from its inputs.
        for child in self.children:
            child.update_colunms(relation_def)
        if self.node_type == "RELATION":
            # Copy so nodes never share (and cannot corrupt) the schema's list.
            self.colunms = list(relation_def[self.name])
        elif self.node_type == "PROJECTION":
            project_cols = [s.strip() for s in self.params.split(",")]
            self.colunms = [col for col in self.children[0].colunms if col in project_cols]
        elif self.node_type == "SELECT":
            self.colunms = self.children[0].colunms
        elif self.node_type == "JOIN":
            merged = []
            for child in self.children:
                merged.extend(child.colunms)
            self.colunms = merged
        # Any other node type (e.g. the ROOT sentinel) keeps its current value.

    def _known_colunms(self):
        '''Return this node's columns, computing them on demand if never calculated.'''
        if self.colunms is None:
            self.update_colunms()
        return self.colunms or []

    def allChildren(self):
        '''Yield this node first, then every descendant in post-order.

        The traversal is snapshotted before the first node is yielded, so the
        caller may restructure the tree while iterating. As a side effect each
        visited child's `parent` pointer is reset to the node it was reached
        from.
        '''
        ordered = [self]

        def _collect(node):
            for child in node.children:
                child.parent = node
                _collect(child)
                ordered.append(child)

        _collect(self)
        yield from ordered

    def addChild(self, child):
        '''Append `child` as the last input of this node and re-parent it.'''
        child.parent = self
        self.children.append(child)

    def removeChild(self, child):
        '''Detach `child` from this node's inputs (its parent pointer is left as is).

        Raises:
            ValueError: if `child` is not one of this node's children.
        '''
        self.children.remove(child)

    def _replace_child(self, old, new):
        '''Put `new` at the position `old` occupies in children and re-parent `new`.

        Replacing in place (rather than remove + append) keeps JOIN operand order.
        '''
        self.children[self.children.index(old)] = new
        new.parent = self

    def _require_parent(self):
        '''Return the parent, or raise ValueError when this node is a bare root.'''
        if self.parent is None:
            raise ValueError(
                f"{self.node_type} {self.name} has no parent; attach the tree "
                "to a sentinel ROOT node before calling push_down"
            )
        return self.parent

    def decompose(self):
        '''Turn SELECT [A & B & ...] (x) into SELECT [A] (SELECT [B] (... (x)))

        Every "&"-separated condition gets its own SELECT node; this node keeps
        the first condition and the rest are chained beneath it in order.
        Nodes that are not SELECT, or whose condition has no "&", are untouched.
        '''
        if self.node_type != "SELECT" or "&" not in self.params:
            return
        conditions = [part.strip() for part in self.params.split("&") if part.strip()]
        if not conditions:
            return
        self.params = conditions[0]
        # Build the chain bottom-up so the last condition sits directly on the input.
        subtree = self.children[0]
        for condition in reversed(conditions[1:]):
            subtree = Node("SELECT", params=condition, children=[subtree])
        self.children = [subtree]
        subtree.parent = self

    def push_down(self, verbose=False) -> bool:
        '''
        Try to move this SELECT / PROJECTION one level closer to the relations.

        1. Turn SELECT [A] (PROJECTION [B] (xxxx)) into PROJECTION [B] (SELECT [A] (xxxx))
        2. Turn SELECT [A] (RA JOIN RB) into RA JOIN SELECT [A] (RB) if applicable
        3. Turn PROJECTION [A] (RA JOIN RB) into RA JOIN PROJECTION [A] (RB) if applicable
        4. Turn PROJECTION [A, B] (RA JOIN RB) into JOIN (PROJECTION [A] (RA)) (PROJECTION [B] (RB)) if applicable

        Column availability is read from `colunms`; call update_colunms() on
        the tree after each structural change.

        Returns:
            True if the tree was changed, False otherwise (including for nodes
            without children or of any other type).

        Raises:
            ValueError: if a rewrite applies but this node has no parent to
                re-attach the rewritten subtree to.
        '''
        if not self.children:
            return False
        child = self.children[0]
        if self.node_type == "SELECT":
            if child.node_type == "PROJECTION":
                return self._push_select_below_projection(child, verbose)
            if child.node_type == "JOIN":
                return self._push_select_into_join(child, verbose)
            return False
        if self.node_type == "PROJECTION" and child.node_type == "JOIN":
            return self._push_projection_into_join(child, verbose)
        return False

    def _push_select_below_projection(self, projection_node, verbose):
        '''Case 1: swap this SELECT with the PROJECTION directly beneath it.'''
        if verbose:
            print(f"case 1: Push SELECT {self.name} down to PROJECTION {projection_node.name}")
        parent = self._require_parent()
        projection_input = projection_node.children[0]
        parent._replace_child(self, projection_node)
        self._replace_child(projection_node, projection_input)
        projection_node._replace_child(projection_input, self)
        return True

    def _push_select_into_join(self, join_node, verbose):
        '''Case 2: move this SELECT onto the JOIN operand that owns its column.'''
        if verbose:
            print(f"case 2: Push SELECT {self.name} down to JOIN {join_node.name}")
        # The column is whatever precedes the first comparison operator.
        col_name = re.split(r"\s*(?:<=|>=|!=|<>|=|<|>)\s*", self.params, maxsplit=1)[0].strip()
        for operand in join_node.children:
            if col_name in operand._known_colunms():
                parent = self._require_parent()
                # Rewire parent -> JOIN -> SELECT -> operand, fixing every
                # parent pointer on the way so the tree stays consistent.
                parent._replace_child(self, join_node)
                join_node._replace_child(operand, self)
                self._replace_child(join_node, operand)
                return True
        return False

    def _push_projection_into_join(self, join_node, verbose):
        '''Cases 3 and 4: replace this PROJECTION by per-operand PROJECTIONs under a new JOIN.'''
        if verbose:
            print(f"case 3&4: Push PROJECTION {self.name} down to JOIN {join_node.name}")
        col_names = [s.strip() for s in self.params.split(",")]
        # Decide first, build later: no node (and no label) is created unless
        # the rewrite actually applies.
        plan = []
        for operand in join_node.children:
            available = operand._known_colunms()
            plan.append((operand, [col for col in col_names if col in available]))
        if not any(kept for _, kept in plan):
            return False
        parent = self._require_parent()
        new_join_node = Node("JOIN", params=None, children=[])
        for operand, kept in plan:
            if kept:
                # One PROJECTION per operand carrying all of its requested
                # columns, so an operand is never duplicated under the JOIN.
                new_join_node.addChild(Node("PROJECTION", params=", ".join(kept), children=[operand]))
            else:
                new_join_node.addChild(operand)
        parent._replace_child(self, new_join_node)
        return True

    def __repr__(self):
        if self.children:
            return f"{self.node_type}({', '.join(map(str, self.children))})"
        else:
            return f"{self.node_type}({self.name})"

    def mermaid_source(self):
        '''Return the Mermaid "graph BT" source describing the subtree rooted here.

        JOIN nodes are drawn as diamonds, every other node as a box labelled
        "<type>-<name>"; each child->parent edge is labelled with the parent's
        params when it has any.
        '''
        lines = ["graph BT"]

        def dfs(node):
            if node.node_type == "JOIN":
                lines.append("%s{%s}" % (node.name, node.node_type))
            else:
                lines.append(f"{node.name}[{node.node_type}-{node.name}]")
            params_str = f"|{node.params}|" if node.params else ""
            for child in node.children:
                dfs(child)
                lines.append(f"{child.name}-->{params_str}{node.name}")

        dfs(self)
        return "\n".join(lines)

    def mermaid(self):
        '''Render the subtree rooted at this node as a Mermaid diagram.'''
        mm(self.mermaid_source())

def _top_level_join_span(text):
    '''Return the (start, end) span of the first JOIN keyword that is not nested
    inside parentheses, brackets or a single-quoted literal, or None.'''
    depth = 0
    in_quote = False
    for token in re.finditer(r"'|[\[(]|[\])]|\bJOIN\b", text):
        symbol = token.group(0)
        if symbol == "'":
            in_quote = not in_quote
        elif in_quote:
            continue
        elif symbol in ("[", "("):
            depth += 1
        elif symbol in ("]", ")"):
            depth -= 1
        elif depth <= 0:
            # Only "JOIN" can reach this branch; <= tolerates a stray closer.
            return token.span()
    return None


def parse_query(query, verbose=False):
    '''Parse a relational-algebra expression into a tree of Node objects.

    Supported forms (keywords are case-sensitive):
        SELECT [ condition ] ( expr )
        PROJECTION [ col, col, ... ] ( expr )
        expr JOIN expr        (several JOINs nest to the right)
        ( expr )
        RELATION_NAME

    Args:
        query: the expression text.
        verbose: print each parsing step, for this call and all nested ones.

    Returns:
        The Node at the top of the parsed expression. Child nodes are created
        before their operator, so automatic labels are handed out bottom-up.

    Raises:
        ValueError: if a SELECT / PROJECTION lacks its "[...]" parameters or
            its "(...)" input.
    '''
    if verbose:
        print("Parsing: ", query)
    text = query.strip()

    # A JOIN outside every bracket is the outermost operator, wherever it
    # appears; looking only at the leftmost keyword would silently drop a
    # trailing "JOIN X" after a parenthesised operand.
    join_span = _top_level_join_span(text)
    if join_span is not None:
        node_type = "JOIN"
        left = text[:join_span[0]].strip()
        right = text[join_span[1]:].strip()
        if verbose:
            print("Found: ", node_type, " + ", left, " + ", right)
        left_node = parse_query(left, verbose)
        right_node = parse_query(right, verbose)
        return Node(node_type, None, [left_node, right_node])

    # Match the SELECT / PROJECTION keywords at the start of the expression.
    operator = re.match(r"(SELECT|PROJECTION)\b", text)
    if operator:
        node_type = operator.group(1)
        params_match = re.search(r'\[([^\[\]]*(?:\[[^\[\]]*\])?[^\[\]]*)\]', text)
        if params_match is None:
            raise ValueError(f"{node_type} is missing its [ ... ] parameters: {text!r}")
        # Look for the input only after the parameters, so a parenthesis
        # inside the condition is not mistaken for the operand.
        child_match = re.search(r'\((.*)\)', text[params_match.end():], re.DOTALL)
        if child_match is None:
            raise ValueError(f"{node_type} is missing its ( ... ) input: {text!r}")
        params = params_match.group(1).strip()
        child_str = child_match.group(1).strip()
        if verbose:
            print("Found: ", node_type, "+", params, "+", child_str)
        return Node(node_type, params, [parse_query(child_str, verbose)])

    # A fully parenthesised sub-expression: drop the wrapper and parse the inside.
    if text.startswith("(") and text.endswith(")"):
        return parse_query(text[1:-1], verbose)

    # Anything else is a relation name.
    return Node("RELATION", None, [], text)

In [160]:
# Parse the query and hang it under a sentinel ROOT node, so that every real
# operator has a parent it can be re-attached to while the tree is rewritten.
print("Initial:")
tree_root = Node("ROOT", None, [])
tree_root.addChild(parse_query(TESTCASE3))
tree_root.mermaid()

# Step 1: break conjunctive selections into chains of single-condition SELECTs.
print("After decompose:")
for node in tree_root.allChildren():
    node.decompose()
tree_root.mermaid()

# Step 2: repeat push-down passes (at most 9) until a pass changes nothing.
for i in range(1, 10):
    print(f"Push down for the {i}th time:")
    changed = False
    for node in tree_root.allChildren():
        # Columns are refreshed before every attempt because the previous
        # push_down may have restructured the tree.
        tree_root.update_colunms()
        changed = node.push_down(verbose=True) or changed
    if not changed:
        print("No more changes!")
        break
    tree_root.mermaid()
print("Finished!")

Initial:


After decompose:


Push down for the 1th time:
case 3&4: Push PROJECTION N down to JOIN M
case 2: Push SELECT O down to JOIN P


Push down for the 2th time:
case 1: Push SELECT O down to PROJECTION Q


Push down for the 3th time:
No more changes!
Finished!
