In [None]:
import os
import sys
sys.path.append("..")

import networkx as nx
import matplotlib as plt

from trace_analyzer import TraceAnalyzer

sim_result_root = "../sim_result"


In [None]:
def _find_unique_node(sub, op_type):
    """Return the single node of ``sub`` whose "op_type" attribute equals ``op_type``."""
    matches = [n for n, attr in sub.nodes(data=True) if attr["op_type"] == op_type]
    assert len(matches) == 1, f"expected exactly one {op_type!r} node, found {len(matches)}"
    return matches[0]


def _tile_cycle_bounds(sub, wsrc, insrc, workers, cnt):
    """Return per-tile (earliest start_cycle, latest end_cycle) lists of length ``cnt``.

    For each source node, only its first ``int(sub.nodes[src]["cnt"])`` packets
    (by ascending packet id) on every source->worker edge are considered, and
    packet number t contributes to tile t. Tiles that receive no packet keep
    start=+inf / end=-inf.
    """
    start = [float("inf")] * cnt
    end = [-float("inf")] * cnt

    for src in (wsrc, insrc):
        src_cnt = int(sub.nodes[src]["cnt"])
        for w in workers:
            pkts = sub.edges[src, w]["pkt"]
            # Packet ids are sorted so that the t-th id maps to tile t
            # (assumes ids increase with tile order -- TODO confirm in trace format).
            pids = sorted(pkts)
            for t in range(src_cnt):
                pkt = pkts[pids[t]]
                start[t] = min(start[t], pkt["start_cycle"])
                end[t] = max(end[t], pkt["end_cycle"])

    return start, end


def get_real_tile_latency(trace_analyzer: "TraceAnalyzer", taskname: str, layer: int, batch: int, tile=None):
    """Measure tile latencies (in cycles) for one layer/batch of a simulation trace.

    The trace graph is restricted to nodes whose "layer" and "batch" attributes
    match. That subgraph must contain exactly one "wsrc" node and exactly one
    "insrc" node, and there must be an edge from each of them to every "worker"
    node. On every such edge, the t-th packet by ascending packet id belongs to
    tile t. A tile's latency is the span from the earliest "start_cycle" to the
    latest "end_cycle" among its packets, taken over all workers and both sources.

    Args:
        trace_analyzer: provides the trace graph via ``.graph.get_graph()``.
        taskname: not used by the computation; kept for call-site compatibility.
        layer: value of the node "layer" attribute to select.
        batch: value of the node "batch" attribute to select.
        tile: index of a single tile, or None to get every tile.

    Returns:
        ``end - start`` for the requested tile. When ``tile`` is None, a list
        with one such value per tile, of length
        ``max(wsrc "cnt", insrc "cnt")``.

    Raises:
        AssertionError: no nodes match, a source node is missing or duplicated,
            or ``tile`` is outside ``[0, cnt)``.
        TypeError: ``tile`` is neither an int nor None.
    """
    graph = trace_analyzer.graph.get_graph()

    G = graph.subgraph(nodes=[n for n, nattr in graph.nodes(data=True)
        if nattr["layer"] == layer
        and nattr["batch"] == batch
        ])

    assert len(G.nodes) > 0, f"no trace nodes for layer={layer}, batch={batch}"

    wsrc = _find_unique_node(G, "wsrc")
    insrc = _find_unique_node(G, "insrc")
    workers = [n for n, attr in G.nodes(data=True) if attr["op_type"] == "worker"]

    # The tile count is the larger of the two sources' "cnt" attributes; a source
    # with a smaller count simply contributes nothing to the trailing tiles.
    cnt = max(int(G.nodes[wsrc]["cnt"]), int(G.nodes[insrc]["cnt"]))

    if tile is not None:
        if not isinstance(tile, int):
            raise TypeError(f"tile must be an int or None, got {type(tile).__name__}")
        assert 0 <= tile < cnt, f"tile {tile} out of range [0, {cnt})"

    # Both the single-tile and all-tiles queries share one aggregation so that
    # they always agree (each source only contributes to tiles below its cnt).
    start, end = _tile_cycle_bounds(G, wsrc, insrc, workers, cnt)
    latencies = [e - s for s, e in zip(start, end)]

    if tile is None:
        return latencies
    return latencies[tile]


# Walk the simulation results and analyze only the first directory that holds
# at least four files (per the original heuristic, fewer files means out.log is
# absent). For every layer of that task, print the tile latencies of batch 0.
for root, dirs, files in os.walk(sim_result_root):
    if len(files) >= 4:
        taskname = os.path.basename(root)
        print(taskname)
        trace_analyzer = TraceAnalyzer(taskname)

        for layer in trace_analyzer.get_layers():
            res = get_real_tile_latency(trace_analyzer, taskname, layer, 0)
            print(res)

        break  # stop after the first qualifying directory