In [3]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import os  # cell-local import: only used to create the output directory

# Create figure and axes
fig, ax = plt.subplots(figsize=(12, 8))

# Box geometry shared by draw_node() and the arrow anchors below, so the boxes
# and the arrows that connect them are always computed from the same numbers.
LINE_HEIGHT = 0.5                       # box height reserved per line of text
WIDE_NODES = ("ETL Pipeline", "MySQL")  # these boxes are 3 wide, all others 2.5

# Define node data.
# (x, y) is the bottom-left corner of each box in data coordinates.
# The CSV boxes are 2.5 wide and placed 3 apart so they do not overlap; the
# boxes below them are centred under that row (row centre is x = 6.25), and
# the y values leave a gap between consecutive boxes for the arrows.
nodes = [
    {"name": "dim_date.csv", "content": "dim_date.csv", "x": 0.5, "y": 8, "color": "lightgreen"},
    {"name": "dim_product.csv", "content": "dim_product.csv", "x": 3.5, "y": 8, "color": "lightgreen"},
    {"name": "dim_store.csv", "content": "dim_store.csv", "x": 6.5, "y": 8, "color": "lightgreen"},
    {"name": "fact_sales.csv", "content": "fact_sales.csv", "x": 9.5, "y": 8, "color": "lightgreen"},
    {"name": "ETL Pipeline", "content": "ETL Pipeline (etl.py)\nExtract: Read CSVs\nTransform: Clean data\nLoad: Insert into MySQL", "x": 4.75, "y": 4, "color": "lightcoral"},
    {"name": "MySQL", "content": "MySQL (retail_dw)\ndim_date\ndim_product\ndim_store\nfact_sales", "x": 4.75, "y": 0.5, "color": "lightblue"},
    {"name": "Loaded Tables", "content": "Loaded Tables", "x": 5, "y": -1.5, "color": "lightyellow"}
]


def _node_size(node):
    """Return the (width, height) of a node's box in data units."""
    width = 3 if node["name"] in WIDE_NODES else 2.5
    height = len(node["content"].split('\n')) * LINE_HEIGHT
    return width, height


def _top_center(node):
    """Return the (x, y) midpoint of the top edge of a node's box."""
    width, height = _node_size(node)
    return node["x"] + width / 2, node["y"] + height


def _bottom_center(node):
    """Return the (x, y) midpoint of the bottom edge of a node's box."""
    width, _ = _node_size(node)
    return node["x"] + width / 2, node["y"]


# Function to draw a node box
def draw_node(ax, node):
    """Draw one node as a filled rectangle with its text centred inside.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    node : dict with keys "name", "content", "x", "y" and "color".
        "x"/"y" give the bottom-left corner of the box; "content" may contain
        newlines, each line getting its own row of text inside the box.
    """
    width, height = _node_size(node)
    rect = patches.Rectangle((node["x"], node["y"]), width, height, linewidth=1, edgecolor='black', facecolor=node["color"])
    ax.add_patch(rect)
    for i, line in enumerate(node["content"].split('\n')):
        # Rows are stacked from the top of the box downwards; (i + 0.5) puts
        # the text in the vertical middle of its row.
        ax.text(node["x"] + width / 2, node["y"] + height - (i + 0.5) * LINE_HEIGHT, line, ha='center', va='center', fontsize=9)


# Draw nodes
for node in nodes:
    draw_node(ax, node)


# Draw arrows
def draw_arrow(ax, start, end):
    """Draw a straight black arrow from `start` to `end` (both (x, y) pairs)."""
    ax.annotate(
        '', xy=(end[0], end[1]), xytext=(start[0], start[1]),
        arrowprops=dict(arrowstyle='->', color='black')
    )


# Arrow endpoints are derived from the boxes themselves (bottom edge of the
# source, top edge of the target) instead of being hard-coded, so they land on
# the box borders whatever the layout above is.
nodes_by_name = {node["name"]: node for node in nodes}
etl_node = nodes_by_name["ETL Pipeline"]
mysql_node = nodes_by_name["MySQL"]
loaded_node = nodes_by_name["Loaded Tables"]

# Arrows from CSVs to ETL
for node in nodes:
    if node["name"].endswith(".csv"):
        draw_arrow(ax, _bottom_center(node), _top_center(etl_node))
# Arrow from ETL to MySQL
draw_arrow(ax, _bottom_center(etl_node), _top_center(mysql_node))
# Arrow from MySQL to Loaded Tables
draw_arrow(ax, _bottom_center(mysql_node), _top_center(loaded_node))

# Set plot limits (wide/tall enough to contain every box) and remove axes
ax.set_xlim(0, 12.5)
ax.set_ylim(-2.5, 9)
ax.axis('off')

# Save the diagram
output_path = 'docs/workflow_diagram.png'
# savefig does not create missing directories, so make sure docs/ exists.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.savefig(output_path, bbox_inches='tight', dpi=300)
plt.close(fig)  # close this specific figure rather than "whatever is current"
print(f"Workflow diagram saved as {output_path}")

Workflow diagram saved as docs/workflow_diagram.png


In [4]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import os  # cell-local import: only used to create the output directory

# Create figure and axes
fig, ax = plt.subplots(figsize=(14, 9))

# Box geometry shared by draw_node() and the arrow anchors below, so the boxes
# and the arrows that connect them are always computed from the same numbers.
LINE_HEIGHT = 0.5                       # box height reserved per line of text
WIDE_NODES = ("ETL Pipeline", "MySQL")  # these boxes are 3 wide, all others 2.5

# Define node data.
# (x, y) is the bottom-left corner of each box in data coordinates.
# The CSV boxes are 2.5 wide and placed 3 apart so they do not overlap; the
# boxes below them are centred under that row (row centre is x = 6.25), and
# the y values leave a gap between consecutive boxes for arrows and labels.
nodes = [
    {"name": "dim_date.csv", "content": "dim_date.csv", "x": 0.5, "y": 8, "color": "lightgreen"},
    {"name": "dim_product.csv", "content": "dim_product.csv", "x": 3.5, "y": 8, "color": "lightgreen"},
    {"name": "dim_store.csv", "content": "dim_store.csv", "x": 6.5, "y": 8, "color": "lightgreen"},
    {"name": "fact_sales.csv", "content": "fact_sales.csv", "x": 9.5, "y": 8, "color": "lightgreen"},
    {"name": "ETL Pipeline", "content": "ETL Pipeline (etl.py)\nExtract: Read CSVs\nTransform: Clean data\nLoad: Insert into MySQL", "x": 4.75, "y": 4, "color": "lightcoral"},
    {"name": "MySQL", "content": "MySQL (retail_dw)\ndim_date\ndim_product\ndim_store\nfact_sales", "x": 4.75, "y": 0.5, "color": "lightblue"},
    {"name": "Loaded Tables", "content": "Loaded Tables", "x": 5, "y": -1.5, "color": "lightyellow"}
]


def _node_size(node):
    """Return the (width, height) of a node's box in data units."""
    width = 3 if node["name"] in WIDE_NODES else 2.5
    height = len(node["content"].split('\n')) * LINE_HEIGHT
    return width, height


def _top_center(node):
    """Return the (x, y) midpoint of the top edge of a node's box."""
    width, height = _node_size(node)
    return node["x"] + width / 2, node["y"] + height


def _bottom_center(node):
    """Return the (x, y) midpoint of the bottom edge of a node's box."""
    width, _ = _node_size(node)
    return node["x"] + width / 2, node["y"]


# Function to draw a node box
def draw_node(ax, node):
    """Draw one node as a filled rectangle with its text centred inside.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    node : dict with keys "name", "content", "x", "y" and "color".
        "x"/"y" give the bottom-left corner of the box; "content" may contain
        newlines, each line getting its own row of text inside the box.

    Returns
    -------
    tuple
        (x, y) of the centre of the drawn box, in data coordinates.
    """
    width, height = _node_size(node)
    rect = patches.Rectangle((node["x"], node["y"]), width, height, linewidth=1, edgecolor='black', facecolor=node["color"])
    ax.add_patch(rect)
    for i, line in enumerate(node["content"].split('\n')):
        # Rows are stacked from the top of the box downwards; (i + 0.5) puts
        # the text in the vertical middle of its row.
        ax.text(node["x"] + width / 2, node["y"] + height - (i + 0.5) * LINE_HEIGHT, line, ha='center', va='center', fontsize=9)
    return node["x"] + width / 2, node["y"] + height / 2


# Draw nodes and record each box centre by node name. The arrows below attach
# to box edges rather than centres, so this mapping is informational only.
node_centers = {node["name"]: draw_node(ax, node) for node in nodes}


# Draw arrows with labels
def draw_arrow(ax, start_xy, end_xy, label="", text_offset=(0, 0)):
    """Draw a slightly curved black arrow with an optional text label.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    start_xy, end_xy : (x, y) pairs; the arrow head is at `end_xy`.
    label : text placed at the midpoint of the straight line between the two
        points; nothing is drawn when it is empty.
    text_offset : (dx, dy) added to that midpoint, to move the label off the
        arrow line.
    """
    ax.annotate(
        '', xy=end_xy, xytext=start_xy,
        arrowprops=dict(arrowstyle='->', color='black', connectionstyle='arc3,rad=0.1')
    )
    if label:
        text_x = (start_xy[0] + end_xy[0]) / 2 + text_offset[0]
        text_y = (start_xy[1] + end_xy[1]) / 2 + text_offset[1]
        ax.text(text_x, text_y, label, ha='center', va='center', fontsize=8)


# Arrow endpoints are derived from the boxes themselves (bottom edge of the
# source, top edge of the target) instead of being hard-coded, so they land on
# the box borders whatever the layout above is.
nodes_by_name = {node["name"]: node for node in nodes}
etl_node = nodes_by_name["ETL Pipeline"]
mysql_node = nodes_by_name["MySQL"]
loaded_node = nodes_by_name["Loaded Tables"]

# Arrows from CSVs to ETL
for node in nodes:
    if node["name"].endswith(".csv"):
        draw_arrow(ax, _bottom_center(node), _top_center(etl_node), "Extract")

# Arrow from ETL to MySQL; the label is shifted right so it does not sit on
# top of the vertical arrow.
draw_arrow(ax, _bottom_center(etl_node), _top_center(mysql_node), "Load", text_offset=(0.5, 0))

# Arrow from MySQL to Loaded Tables (representing the result)
draw_arrow(ax, _bottom_center(mysql_node), _top_center(loaded_node), "Result", text_offset=(0.5, 0))

# Set plot limits (wide/tall enough to contain every box) and remove axes
ax.set_xlim(0, 12.5)
ax.set_ylim(-2.5, 9)
ax.axis('off')
ax.set_title('Data Pipeline Workflow', fontsize=14)

# Save the diagram
output_path = 'docs/workflow_diagram_improved.png'
# savefig does not create missing directories, so make sure docs/ exists.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.savefig(output_path, bbox_inches='tight', dpi=300)
plt.close(fig)  # close this specific figure rather than "whatever is current"
print(f"Improved workflow diagram saved as {output_path}")

Improved workflow diagram saved as docs/workflow_diagram_improved.png
