In [None]:
import json
import os
import re

import numpy as np
from matplotlib import pyplot as plt

In [None]:
def _node_sort_key(node):
    """Sort key that orders node IDs naturally ("node2" before "node10").

    JSON object keys are always strings, so a plain ``sorted`` would place
    "10" before "2". Runs of decimal digits are compared as integers and all
    other runs as text; the 0/1 tag keeps ints and strs from ever being
    compared with each other. The raw text is the final tie-breaker so the
    order stays deterministic (e.g. "01" vs "1").
    """
    text = str(node)
    parts = [(0, int(tok)) if tok.isdecimal() else (1, tok)
             for tok in re.split(r"(\d+)", text)]
    return parts, text


def _unpack_series(points):
    """Split ``[[t, value], ...]`` into time-sorted ``(times, values)`` lists.

    Points may arrive in any order (e.g. newest first), so they are sorted by
    timestamp before unzipping. Missing/empty input yields two empty lists.
    """
    if not points:
        return [], []
    times, values = zip(*sorted(points, key=lambda p: p[0]))
    return list(times), list(values)


def plot_network_traffic(metrics, main_title):
    """
    Plot per-node time series plus per-node summary bar charts for one metrics dump.

    The figure has one column per node and seven rows:
      rows 1-4  outgoing traffic, incoming traffic, DB size and total state size
                over time (one line plot per node),
      row 5     total traffic per node (sum of all incoming + outgoing sizes),
      row 6     final (latest-timestamp) total state size per node,
      row 7     final (latest-timestamp) DB size per node.

    Expected structure (any key may be missing or null; missing series are
    drawn as empty panels marked "no data"):
    {
        "outgoing_traffic": { node: [[t, size], ...], ... },
        "incoming_traffic": { node: [[t, size], ...], ... },
        "db_sizes": { node: [[t, size], ...], ... },
        "total_state_sizes": { node: [[t, size], ...], ... }
    }

    Args:
        metrics: Mapping in the structure above, or None (treated as empty).
        main_title: Figure super-title.

    Returns:
        None. Prints a message and returns early when no node data exists.
        Otherwise the figure is shown and then closed, so calling this in a
        loop does not accumulate open figures.
    """
    # (metrics key, panel title suffix, line color) — one time-series row each, in plot order.
    time_rows = [
        ("outgoing_traffic", "Outgoing Traffic", "tab:blue"),
        ("incoming_traffic", "Incoming Traffic", "tab:green"),
        ("db_sizes", "DB Size", "tab:orange"),
        ("total_state_sizes", "State Size", "tab:red"),
    ]

    metrics = metrics or {}
    # `or {}` also covers an explicit JSON null for a series.
    series = {key: metrics.get(key) or {} for key, _, _ in time_rows}

    nodes = sorted(set().union(*(s.keys() for s in series.values())),
                   key=_node_sort_key)
    n_nodes = len(nodes)
    if n_nodes == 0:
        print("No node data found.")
        return

    # Grid: 7 rows total — 4 rows of per-node line plots + 3 full-width bar charts.
    fig = plt.figure(figsize=(5 * n_nodes, 18))
    gs = fig.add_gridspec(7, n_nodes, height_ratios=[1, 1, 1, 1, 0.5, 0.5, 0.5])
    fig.suptitle(main_title, fontsize=18, weight="bold")

    total_traffic = {}
    final_db = {}
    final_state = {}

    # --- Per-node time-series plots ---
    for col, node in enumerate(nodes):
        unpacked = {key: _unpack_series(series[key].get(node))
                    for key, _, _ in time_rows}

        for row, (key, panel_title, color) in enumerate(time_rows):
            times, values = unpacked[key]
            ax = fig.add_subplot(gs[row, col])
            ax.plot(times, values, color=color, marker="o")
            if not times:
                ax.text(0.5, 0.5, "no data", ha="center", va="center",
                        transform=ax.transAxes, color="gray")
            ax.set_title(f"Node {node} - {panel_title}")
            ax.set_xlabel("Time")
            ax.set_ylabel("Bytes")
            ax.grid(True, linestyle="--", alpha=0.5)

        # Traffic entries are individual message sizes, so they are summed;
        # DB/state entries are snapshots, so the latest one is the final size.
        total_traffic[node] = (sum(unpacked["incoming_traffic"][1])
                               + sum(unpacked["outgoing_traffic"][1]))
        db_values = unpacked["db_sizes"][1]
        state_values = unpacked["total_state_sizes"][1]
        final_db[node] = db_values[-1] if db_values else 0
        final_state[node] = state_values[-1] if state_values else 0

    # --- Summary bar charts (rows 5–7) ---
    bar_rows = [
        (total_traffic, "Total Network Traffic per Node (Incoming + Outgoing)",
         "Total Bytes", "tab:purple"),
        (final_state, "Final Total State Size per Node", "Bytes", "tab:red"),
        (final_db, "Final DB Size per Node", "Bytes", "tab:orange"),
    ]
    x = np.arange(n_nodes)
    node_labels = [str(node) for node in nodes]
    for row, (per_node, title, ylabel, color) in enumerate(bar_rows, start=len(time_rows)):
        ax = fig.add_subplot(gs[row, :])
        ax.bar(x, [per_node[node] for node in nodes], color=color, alpha=0.7)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xticks(x)
        ax.set_xticklabels(node_labels)
        ax.grid(True, axis="y", linestyle="--", alpha=0.5)

    fig.tight_layout(rect=[0, 0.03, 1, 0.96])
    plt.show()
    # Release the figure; otherwise repeated calls keep every figure alive.
    plt.close(fig)

In [None]:
# def plot_network_traffic(metrics, main_title):
#     """
#     Plots incoming and outgoing traffic per node in a 2×N grid.
    
#     metrics: dict with keys 'incoming_traffic' and 'outgoing_traffic'.
#              Each maps node -> list of [timestamp, msg_size].
#     """
#     outgoing = metrics.get("outgoing_traffic", {})
#     incoming = metrics.get("incoming_traffic", {})
    
#     nodes = sorted(set(outgoing.keys()) | set(incoming.keys()))
#     n = len(nodes)

#     fig, axes = plt.subplots(2, n, figsize=(5*n, 8), sharex=False)
#     fig.suptitle(main_title, fontsize=16)

#     # Handle the case when n == 1 (axes shape differences)
#     if n == 1:
#         axes = axes.reshape(2, 1)

#     for i, node in enumerate(nodes):
#         out_data = outgoing.get(node, [])
#         in_data = incoming.get(node, [])
        
#         # Unpack time and sizes
#         if out_data:
#             t_out, s_out = zip(*sorted(out_data))
#         else:
#             t_out, s_out = [], []
#         if in_data:
#             t_in, s_in = zip(*sorted(in_data))
#         else:
#             t_in, s_in = [], []

#         # Plot outgoing traffic
#         axes[0, i].plot(t_out, s_out, label='Outgoing', color='tab:blue', marker='o')
#         axes[0, i].set_title(f"Node {node} - Outgoing")
#         axes[0, i].set_xlabel("Time")
#         axes[0, i].set_ylabel("Message Size (bytes)")
#         axes[0, i].grid(True, linestyle='--', alpha=0.5)

#         # Plot incoming traffic
#         axes[1, i].plot(t_in, s_in, label='Incoming', color='tab:green', marker='o')
#         axes[1, i].set_title(f"Node {node} - Incoming")
#         axes[1, i].set_xlabel("Time")
#         axes[1, i].set_ylabel("Message Size (bytes)")
#         axes[1, i].grid(True, linestyle='--', alpha=0.5)

#     plt.tight_layout(rect=[0, 0.03, 1, 0.95])
#     plt.show()

In [None]:
def plot_metrics_from_dir(directory):
    """
    Plot every ``*.json`` metrics dump in ``directory`` with ``plot_network_traffic``.

    Files are processed in sorted filename order, and each figure is titled
    with its filename. Files that cannot be read or parsed, or whose top-level
    JSON value is not an object, are reported and skipped, so one bad dump does
    not stop the rest from being plotted.

    Args:
        directory: Path to the directory containing the metrics JSON files.

    Returns:
        None. Progress and skip messages are printed.
    """
    if not os.path.isdir(directory):
        print(f"Directory '{directory}' not found.")
        return

    json_files = sorted(f for f in os.listdir(directory) if f.endswith(".json"))
    if not json_files:
        print(f"No JSON files found in '{directory}'.")
        return

    for fname in json_files:
        fpath = os.path.join(directory, fname)
        try:
            # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
            with open(fpath, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:
            # Deliberate best-effort: report and move on to the next file.
            print(f"⚠️ Skipping '{fname}' (could not load JSON): {e}")
            continue

        # plot_network_traffic expects a mapping; a top-level list/scalar would crash it.
        if not isinstance(data, dict):
            print(f"⚠️ Skipping '{fname}' (expected a JSON object, got {type(data).__name__})")
            continue

        print(f"📊 Plotting metrics from: {fname}")
        plot_network_traffic(data, fname)

In [None]:
# Render one figure per metrics dump found in ../metrics. The relative path is
# resolved against the kernel's current working directory (typically the
# notebook's own folder — adjust if the kernel was started elsewhere).
plot_metrics_from_dir("../metrics")
