# Construct Join Statement for SQL


In [21]:
from dataclasses import dataclass
from typing import *

In [83]:
@dataclass
class Table:
    """A SQL table identified by its database and name, with an optional alias.

    Declaring the attributes as dataclass fields (instead of hand-writing
    ``__init__``) makes the generated ``__eq__`` and ``__repr__`` reflect the
    table's contents; without declared fields every instance compares equal
    and prints as ``Table()``.

    Attributes:
        database: Name of the database that owns the table.
        name: Table name inside ``database``.
        alias: Alias used to refer to the table in a query. Falls back to
            ``name`` when empty or omitted.
    """

    database: str
    name: str
    alias: str = ""

    def __post_init__(self) -> None:
        # An empty alias means "no alias given": reuse the table name.
        if not self.alias:
            self.alias = self.name

@dataclass
class JoinStep:
    """One step of a join pipeline: join ``source_table`` onto ``target_table``."""
    # Table that SQLJoinConstructor emits in the INNER JOIN clause for this step.
    source_table: Table
    # Table the step joins against; SQLJoinConstructor only emits the step once
    # this table's "database.name" is among the already-available tables.
    target_table: Table
    # (field on source_table, field on target_table) compared in the ON clause.
    join_fields: Tuple[str, str]


class SQLJoinConstructor:
    """SQL Join Constructor"""
    @classmethod
    def construct_join_statement(cls, pipeline: Sequence[JoinStep], existed_tables: set[Table] | None = None, depth: int = 0) -> str:
        """Construct SQL join statement"""
        join_query = ""
        if len(pipeline) == 0:
            return ""
        if depth > 20:
            raise Exception("Maximum depth of recursion reached.")
        local_existed_tables = set() if existed_tables is None else existed_tables
        current_step = pipeline[0]
        remaining_steps = pipeline[1:]
        source_table, target_table = current_step.source_table, current_step.target_table
        target_table_full_name = f"{target_table.database}.{target_table.name}"
        source_table_full_name = f"{source_table.database}.{source_table.name}"
        if target_table_full_name not in local_existed_tables:
            return cls.construct_join_statement(
                pipeline=[*remaining_steps, current_step], 
                existed_tables=local_existed_tables,
                depth=depth+1,
            )
        source_field, target_field = current_step.join_fields[0], current_step.join_fields[1]
        join_statement = f"""
        INNER JOIN {source_table.database}.{source_table.name} as {source_table.alias}
            ON {source_table.alias}.{source_field} = {target_table.alias}.{target_field}
        """
        join_query = join_statement
        local_existed_tables.add(source_table_full_name)

        return join_query + cls.construct_join_statement(
                pipeline=remaining_steps, 
                existed_tables=local_existed_tables,
                depth=depth+1,
            )

In [72]:
# Demo fixtures: four tables in one database, chained d -> c -> b -> a.
database = "applydb_prod"
table_a, table_b, table_c, table_d = (
    Table(database=database, name=f"table_{suffix}") for suffix in "abcd"
)
# Steps are listed in reverse dependency order on purpose: each step's target
# is only introduced by a step that appears later in the list.
join_conf = [
    JoinStep(source_table=table_d, target_table=table_c, join_fields=("column_3", "column_3")),
    JoinStep(source_table=table_c, target_table=table_b, join_fields=("column_2", "column_2")),
    JoinStep(source_table=table_b, target_table=table_a, join_fields=("column_1", "column_1")),
]

In [84]:
# table_a is treated as already present in the query, so it seeds the set of
# available tables that the join steps can attach to.
join_query = SQLJoinConstructor.construct_join_statement(
    pipeline=join_conf,
    existed_tables={f"{table_a.database}.{table_a.name}"},
)
print(f"join_query: {join_query}")

join_query: 
        INNER JOIN applydb_prod.table_b as table_b
            ON table_b.column_1 = table_a.column_1
        
        INNER JOIN applydb_prod.table_c as table_c
            ON table_c.column_2 = table_b.column_2
        
        INNER JOIN applydb_prod.table_d as table_d
            ON table_d.column_3 = table_c.column_3
        
