In [235]:
import pandas as pd
from graphviz import Digraph
import pydotplus
import networkx as nx
from enum import Enum
from copy import deepcopy

In [236]:
# Base name of the dependency data set. DependencyGraph reads its edge list from
# f'{table}.csv' and, when no subgraph is requested, renders to f'graphs/{table}'.
table = 'datapipeline.hist_cp_fill'

In [237]:
class Shape(Enum):
    """Graphviz node attributes, keyed by node type name.

    Each member's ``value`` is a dict of attributes that is merged into the
    keyword arguments of a node declaration. Members are looked up by name
    (``Shape['derived']``).
    """

    # Green record box.
    source = dict(shape='record', fillcolor='#90EE90', style='filled')
    # Pale blue oval.
    derived = dict(shape='oval', fillcolor='#E6F5FF', style='filled')
    # Pink hexagon.
    UNK = dict(shape='hexagon', fillcolor='#FFB6C1', style='filled')


In [238]:
from collections import defaultdict


def dictionary_factory():
    """Return a brand-new, empty ``dict`` on every call."""
    fresh = dict()
    return fresh

def get_ancestors(graph, node_name) -> set:
    """Collect every edge lying on a path that leads into ``node_name``.

    Args:
        graph: Any object exposing ``in_edges(node)`` that yields hashable edge
            tuples whose first element is the predecessor node (for example a
            networkx directed graph).
        node_name: Node whose upstream edges are wanted.

    Returns:
        The set of edge tuples, exactly as yielded by ``graph.in_edges``, for
        ``node_name`` and for every node reachable by walking edges backwards.
        Empty when ``node_name`` has no incoming edges.
    """
    ancestors = set()
    # Track expanded nodes so that cycles and self-loops terminate and shared
    # ancestors are walked once instead of once per path.
    visited = {node_name}
    # Explicit stack instead of recursion: deep chains cannot hit the
    # interpreter's recursion limit.
    pending = [node_name]
    while pending:
        current = pending.pop()
        for edge in graph.in_edges(current):
            ancestors.add(edge)
            parent = edge[0]
            if parent not in visited:
                visited.add(parent)
                pending.append(parent)
    return ancestors

class DependencyGraph:
    def __init__(self,table, subgraph: str|None = None):
        self.table = table
        self.subgraph = subgraph   
        self.create_graph()

    def create_graph(self) -> None:
        self.create_empty_dag()
        node_df = self._load_graph_data()
        for _, row in node_df.iterrows():
            self.add_to_dag(self.shape_lookup, row)
        
        if self.subgraph is not None:
            nx_graph = self._to_networkx() 
            ancestors = get_ancestors(nx_graph, self.subgraph)
            self.create_empty_dag()       
            for _, row in node_df.iterrows():
                if (row['source'], row['target']) in ancestors:
                    self.add_to_dag(self.shape_lookup, row)

    def create_empty_dag(self):
        self.dot = Digraph(format='svg', graph_attr={'rankdir':'LR'})
        self.nodes = dictionary_factory()

    def add_to_dag(self, shape_lookup: dict, row: pd.Series):
        shape = (
            Shape[shape_lookup.get(row['source'], 'derived')].value,
            Shape[shape_lookup.get(row['target'], 'derived')].value
            )
        self.create_node(row, shape)
        self.dot.edge(row['source'], row['target'])

    def _to_networkx(self):
        dotplus = pydotplus.graph_from_dot_data(self.dot.source)
        nx_graph = nx.nx_pydot.from_pydot(dotplus)
        return nx_graph


    def _load_graph_data(self) -> pd.DataFrame:
        try:
            df = pd.read_csv(f'{self.table}.csv')
            self.shape_lookup = dict(zip(df.source, df.type))
        except FileNotFoundError:
            raise FileNotFoundError(f"{self.table}.csv not found.")
        return df
    
    def draw_dot(self) -> None:
        return self.dot

    def render_dot(self) -> None:
        if self.subgraph is not None:
            self.dot.render(filename=f"graphs/{self.subgraph}")
        else:
            self.dot.render(filename=f"graphs/{self.table}")
        return


    def create_node(self, row: pd.Series, shape: tuple) -> None:
        self._create_node(row['source'], shape[0], row['URL'])
        self._create_node(row['target'], shape[1])
    
    def _create_node(self, node_name: str, shape: dict, url: str|None = None) -> None:
        if node_name not in self.nodes:
            node_attrs = {'name': node_name, 'label': node_name, 'URL': url}
            node_attrs.update(shape)
            self.dot.node(**node_attrs)
            self.nodes[node_name] = True
    



In [239]:
# Build the full dependency diagram for `table` (no subgraph filter) ...
graph = DependencyGraph(table)
# ... and render it to graphs/<table> (SVG, per the Digraph format set in the class).
graph.render_dot()