# In [None]:
# -*- coding: utf-8 -*-
"""
q3_app_balanced.py
Objective: Significantly reduce runtime while maintaining optimization quality
- Single enhanced list scheduling (O((N+E) log N))
- Critical-path adjacent fine-tuning (few attempts, negligible overhead)
- Preserves Problem2's memory/spill assignments; only performs topologically equivalent reordering
- Compatible with PyQt5 / PyQt6
"""
import os, sys, csv, re
from pathlib import Path
from collections import defaultdict, deque
import heapq

# ===================== Tunable Parameters (Balance Speed/Quality) =====================
# Number of critical-path local tuning passes handed to improve_by_cp_adjacent_swaps.
# 0 disables the fine-tuning step entirely; each pass re-extracts the critical chain
# from the current finish times and stops early if it finds no improvement.
CP_LOCAL_PASSES = 2
# Max number of swap attempts per pass. Every attempt re-simulates the whole schedule
# (evaluate_makespan), so this caps the fine-tuning cost on huge graphs.
CP_LOCAL_TRY_LIMIT = 256

# ===================== Qt Compatibility Import =====================
try:
    from PyQt5 import QtWidgets, QtCore
    from PyQt5.QtWidgets import (
        QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel,
        QFileDialog, QTableWidget, QTableWidgetItem, QMessageBox, QTextEdit, QSplitter, QSizePolicy, QComboBox
    )
    from PyQt5.QtCore import QObject, pyqtSignal, QThread
    QT_LIB = "PyQt5"
except Exception:
    from PyQt6 import QtWidgets, QtCore
    from PyQt6.QtWidgets import (
        QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel,
        QFileDialog, QTableWidget, QTableWidgetItem, QMessageBox, QTextEdit, QSplitter, QSizePolicy, QComboBox
    )
    from PyQt6.QtCore import QObject, pyqtSignal, QThread
    QT_LIB = "PyQt6"

# ===================== Matplotlib Setup =====================
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("Agg")  # Default to headless; switch to QtAgg when creating windows

if QT_LIB == "PyQt5":
    from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
else:
    from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure

# Candidate sans-serif families, most of which include CJK glyphs, so non-ASCII
# plot text can render; matplotlib tries them in list order.
plt.rcParams['font.sans-serif'] = [
    'SimHei', 'Microsoft YaHei', 'PingFang SC', 'Heiti SC',
    'WenQuanYi Micro Hei', 'Noto Sans CJK JP', 'Arial Unicode MS'
]
# Draw minus signs with the ASCII hyphen; some of the fonts above may lack U+2212.
plt.rcParams['axes.unicode_minus'] = False

# Main-window title; also records which Qt binding was actually imported.
APP_TITLE = f"Problem 3: Runtime Optimization (Balanced Version, {QT_LIB})"

# ===================== Constants =====================
# Capacity per buffer/memory type, keyed by the normalized type names that
# norm_l0_name produces (L0A/L0B/L0C spelled with a digit zero). Units are not
# stated here — presumably the same units as the CSV "Size" column; confirm.
# NOTE(review): not referenced in the visible part of this file.
CAPACITY = {"L1":4096, "UB":1024, "L0A":256, "L0B":256, "L0C":512}
# Case names; for each one read_nodes_edges loads <CSV dir>/<task>_Nodes.csv and
# <CSV dir>/<task>_Edges.csv.
TASKS = [
    "Conv_Case0","Conv_Case1",
    "FlashAttention_Case0","FlashAttention_Case1",
    "Matmul_Case0","Matmul_Case1",
]

# ===================== Utility Functions =====================
def norm_l0_name(x: str) -> str:
    """Trim and upper-case a pipe/buffer-type name, fixing letter-O spellings.

    "LOA"/"LOB"/"LOC" (letter O) are rewritten to "L0A"/"L0B"/"L0C" (digit
    zero). Falsy input ("" or None) is returned unchanged.
    """
    if x:
        name = x.strip().upper()
        for wrong, right in (("LOA", "L0A"), ("LOB", "L0B"), ("LOC", "L0C")):
            name = name.replace(wrong, right)
        return name
    return x

def read_nodes_edges(csv_dir: Path, task: str):
    """Load one task's node and edge CSVs into a metadata dict.

    Reads ``<csv_dir>/<task>_Nodes.csv`` and ``<csv_dir>/<task>_Edges.csv``.

    Node rows (``Op`` is compared case-insensitively):
      * ``ALLOC`` registers a buffer: ``buf_attr[BufId] = (Type, Size)`` and
        ``alloc_node_of[BufId] = Id``.
      * ``FREE`` records ``free_node_of[BufId] = Id``.
      * any other op may list the buffers it touches in ``Bufs`` (e.g.
        ``"[3, 7]"``); they are appended to ``node_uses[Id]``, and buffers
        touched by ``COPY_IN`` ops are also collected in ``copyin_uses_buf``.
    ``Pipe`` and ``Type`` are normalized with ``norm_l0_name`` (empty Pipe
    becomes None); a missing or non-integer ``Cycles`` becomes 0. Invalid
    ``Bufs`` tokens are skipped and counted in a single warning.

    Returns:
        dict with keys ``nodes``, ``edges`` (list of (start, end) ids),
        ``buf_attr``, ``op_pipe``, ``cycles``, ``node_uses``,
        ``alloc_node_of``, ``free_node_of``, ``copyin_uses_buf`` and ``N``
        (max node id + 1, or 0 when there are no nodes).

    Raises:
        FileNotFoundError: if either CSV file is missing.
    """
    nodes_path = csv_dir / f"{task}_Nodes.csv"
    edges_path = csv_dir / f"{task}_Edges.csv"
    if not nodes_path.exists() or not edges_path.exists():
        raise FileNotFoundError(f"Missing CSV files: {nodes_path} or {edges_path}")

    # Characters stripped from the Bufs column before splitting on commas:
    # list brackets, any whitespace, and stray double quotes.
    bufs_junk = re.compile(r'[\[\]\s"]')

    nodes = {}
    buf_attr, op_pipe, cycles = {}, {}, {}
    node_uses = defaultdict(list)
    alloc_node_of, free_node_of = {}, {}
    copyin_uses_buf = set()
    bad_bufs = 0

    with open(nodes_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            nid = int(row["Id"])
            op = (row.get("Op") or "").strip().upper()
            nodes[nid] = row
            p = (row.get("Pipe") or "").strip().upper()
            op_pipe[nid] = norm_l0_name(p) if p else None
            try:
                cycles[nid] = int(row.get("Cycles") or 0)
            except (TypeError, ValueError):
                cycles[nid] = 0

            if op == "ALLOC":
                b = int(row["BufId"])
                t = norm_l0_name(row.get("Type") or "")
                s = int(row["Size"])
                buf_attr[b] = (t, s)
                alloc_node_of[b] = nid
            elif op == "FREE":
                free_node_of[int(row["BufId"])] = nid
            else:
                cleaned = bufs_junk.sub('', (row.get("Bufs") or "").strip())
                for tok in cleaned.split(',') if cleaned else ():
                    if not tok:
                        continue
                    try:
                        b = int(tok)
                    except ValueError:
                        bad_bufs += 1
                        continue
                    node_uses[nid].append(b)
                    if op == "COPY_IN":
                        copyin_uses_buf.add(b)

    edges = []
    with open(edges_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            edges.append((int(row["StartNodeId"]), int(row["EndNodeId"])))

    n_max = max(nodes.keys()) if nodes else -1
    meta = {
        "nodes": nodes, "edges": edges, "buf_attr": buf_attr, "op_pipe": op_pipe,
        "cycles": cycles, "node_uses": node_uses,
        "alloc_node_of": alloc_node_of, "free_node_of": free_node_of,
        "copyin_uses_buf": copyin_uses_buf, "N": n_max + 1
    }
    if bad_bufs:
        print(f"[warn] Skipped {bad_bufs} invalid Bufs tokens")
    return meta

def read_problem_schedule(dir_path: Path, task: str, _tag_unused: str = ""):
    """Read ``<dir_path>/<task>_schedule.txt`` as a list of node ids.

    The file holds one integer per line; blank and non-integer lines are
    ignored. Returns None when the file does not exist. ``_tag_unused`` is
    accepted for call-site readability only and has no effect.
    """
    sch_path = dir_path / f"{task}_schedule.txt"
    if not sch_path.exists():
        return None
    order = []
    with open(sch_path, "r", encoding="utf-8") as fh:
        for raw in fh:
            token = raw.strip()
            if not token:
                continue
            try:
                value = int(token)
            except ValueError:
                continue
            order.append(value)
    return order

def read_memory(dir_path: Path, task: str):
    """Parse ``<dir_path>/<task>_memory.txt`` into ``{buf_id: offset}``.

    Useful lines look like ``buf:offset``. Lines without ``:`` or with a
    non-integer field are skipped; a later line for the same buffer overrides
    an earlier one. Returns an empty dict when the file does not exist.
    """
    m_path = dir_path / f"{task}_memory.txt"
    mem = {}
    if not m_path.exists():
        return mem
    with open(m_path, "r", encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if ":" not in line:
                continue
            left, right = line.split(":", 1)
            try:
                mem[int(left)] = int(right)
            except ValueError:
                pass
    return mem

def read_spill(dir_path: Path, task: str):
    """Parse ``<dir_path>/<task>_spill.txt`` into a list of ``(buf_id, offset)`` pairs.

    Pairs keep file order (duplicates included). Lines without ``:`` or with a
    non-integer field are skipped. Returns ``[]`` when the file does not exist.
    """
    sp_path = dir_path / f"{task}_spill.txt"
    if not sp_path.exists():
        return []
    pairs = []
    with open(sp_path, "r", encoding="utf-8") as fh:
        for raw in fh:
            buf_txt, sep, off_txt = raw.strip().partition(":")
            if not sep:
                continue
            try:
                pairs.append((int(buf_txt), int(off_txt)))
            except ValueError:
                pass
    return pairs

def ensure_dirs(root_path: Path):
    """Validate the project layout under `root_path` and create ``Problem3/``.

    ``CSV/``, ``Problem1/`` and ``Problem2/`` must already exist; ``Problem3/``
    is created if needed.

    Returns:
        (csv_dir, p1_dir, p2_dir, p3_dir) paths.

    Raises:
        FileNotFoundError: listing every required subdirectory that is missing.
    """
    csv_dir, p1_dir, p2_dir, p3_dir = (
        root_path / sub for sub in ("CSV", "Problem1", "Problem2", "Problem3")
    )
    required = (("CSV", csv_dir), ("Problem1", p1_dir), ("Problem2", p2_dir))
    missing_dirs = [label for label, path in required if not path.exists()]
    if missing_dirs:
        raise FileNotFoundError(
            f"Selected directory '{root_path}' is missing required subdirectories: {', '.join(missing_dirs)}"
        )
    p3_dir.mkdir(parents=True, exist_ok=True)
    return csv_dir, p1_dir, p2_dir, p3_dir

# ===================== Graph & Scheduling (Efficient) =====================
def build_graph(N, edges):
    """Build adjacency structures for nodes ``0..N-1``.

    Edges with an endpoint outside ``[0, N)`` are ignored; duplicate edges are
    kept (they appear twice in the lists and count twice in ``indeg``).

    Returns:
        (adj, indeg, pred): successor lists (defaultdict), in-degree list of
        length N, and predecessor lists (defaultdict).
    """
    adj = defaultdict(list)
    pred = defaultdict(list)
    indeg = [0] * N
    for src, dst in edges:
        if not (0 <= src < N and 0 <= dst < N):
            continue
        indeg[dst] += 1
        adj[src].append(dst)
        pred[dst].append(src)
    return adj, indeg, pred

def compute_longest_path_to_sink(N, adj, w):
    """Return, per node, the heaviest weighted path from that node to a sink.

    result[v] = w[v] + max(result[s] for s in adj[v]), or just w[v] for sinks.
    Computed with a reverse Kahn traversal that starts at the sinks; nodes on
    or upstream of a cycle are never reached and end up with only their own
    weight plus whatever partial tail had been propagated to them.
    """
    preds_of = defaultdict(list)
    remaining = [0] * N  # successors not yet finalized
    for u in range(N):
        succs = adj[u]
        remaining[u] = len(succs)
        for v in succs:
            preds_of[v].append(u)

    # tail[v]: heaviest path strictly below v (excludes v's own weight).
    tail = [0] * N
    frontier = deque(u for u in range(N) if remaining[u] == 0)
    while frontier:
        v = frontier.popleft()
        through_v = tail[v] + w[v]
        for u in preds_of[v]:
            if through_v > tail[u]:
                tail[u] = through_v
            remaining[u] -= 1
            if not remaining[u]:
                frontier.append(u)
    return [tail[v] + w[v] for v in range(N)]

def compute_es_ls_slack(N, adj, pred, cycles):
    """Critical-path-method forward/backward pass, ignoring pipeline contention.

    Node durations are ``cycles.get(v, 0)`` with negative values clamped to 0.
    Only nodes reachable by a Kahn topological sort are timed; nodes on a
    cycle keep ES = LS = 0.

    Returns:
        (ES, LS, slack, EF_max): earliest starts, latest starts,
        slack = LS - ES per node, and the overall earliest finish.
    """
    def duration(v):
        d = cycles.get(v, 0)
        return 0 if d < 0 else d

    waiting = [len(pred[v]) for v in range(N)]
    ready = deque(v for v in range(N) if waiting[v] == 0)
    topo = []
    while ready:
        v = ready.popleft()
        topo.append(v)
        for s in adj[v]:
            waiting[s] -= 1
            if waiting[s] == 0:
                ready.append(s)

    # Forward pass: a node may start once all of its predecessors have finished.
    ES = [0] * N
    EF = [0] * N
    for v in topo:
        ES[v] = max([0] + [EF[u] for u in pred[v]])
        EF[v] = ES[v] + duration(v)

    # Backward pass: latest start that still meets the overall earliest finish.
    EF_max = max(EF) if EF else 0
    LF = [EF_max] * N
    LS = [0] * N
    for v in reversed(topo):
        succs = adj[v]
        if succs:
            LF[v] = min(LS[s] for s in succs)
        LS[v] = LF[v] - duration(v)

    slack = [late - early for late, early in zip(LS, ES)]
    return ES, LS, slack, EF_max

def evaluate_makespan(order, adj, indeg0, pred, pipe_of, cycles):
    """Simulate issuing `order` to in-order pipelines and return its timing.

    Each node starts at the later of (a) the finish of the previous node
    issued to the same pipe (only when ``pipe_of`` gives it a truthy pipe) and
    (b) the latest finish of its predecessors. Durations are ``cycles`` with
    negatives clamped to 0. Ids outside ``[0, len(indeg0))`` are skipped;
    ``adj`` is not used.

    Returns:
        (makespan, S, E, per_pipe_segments) where S/E are per-node start/end
        times and per_pipe_segments maps pipe name (or "NONE") to a list of
        ``(start, cycles, node)`` for nodes with positive cycles, in `order`.
    """
    N = len(indeg0)
    if not N:
        return 0, [], [], defaultdict(list)

    S = [0] * N
    E = [0] * N
    pipe_free_at = defaultdict(int)
    for v in order:
        if v < 0 or v >= N:
            continue
        pipe = pipe_of.get(v)
        deps_done = 0
        for u in pred[v]:
            deps_done = max(deps_done, E[u])
        begin = max(deps_done, pipe_free_at[pipe]) if pipe else deps_done
        dur = max(0, cycles.get(v, 0))
        S[v] = begin
        E[v] = begin + dur
        if pipe:
            pipe_free_at[pipe] = E[v]
    total = max(E)

    per_pipe_segments = defaultdict(list)
    for v in order:
        if 0 <= v < N and cycles.get(v, 0) > 0:
            per_pipe_segments[pipe_of.get(v) or "NONE"].append((S[v], cycles.get(v, 0), v))
    return total, S, E, per_pipe_segments

def q3_reschedule_fast(N, edges, base_order, pipe_of, cycles):
    """
    Enhanced list scheduling over the dependency DAG (one node committed per step).

    At every step each ready node (all predecessors already scheduled) is scored by
        (finish_time, -crit, slack, gap_penalty, -down_gain, base_idx)
    and the lexicographically smallest one is scheduled next, where
      * finish_time = max(pipe ready time, predecessors' finish) + duration,
      * crit        = heaviest path from the node to a sink (including itself),
      * slack       = CPM slack computed without pipeline contention,
      * gap_penalty = cycles the node would wait for its pipe after its data is ready,
      * down_gain   = mean crit of its immediate successors (0 for sinks),
      * base_idx    = position in `base_order` (tie-break toward the baseline order).

    Cost: a node's dependency-ready time is final once it becomes ready, so it is
    computed once (O(N + E) in total) instead of being recomputed from all
    predecessors of every ready node at every step, and no per-step heap is built.
    Each step scores the ready set in O(1) per candidate, i.e.
    O(N + E + sum of ready-set sizes) overall — quadratic only for very wide graphs.

    Nodes that never become ready (on or downstream of a dependency cycle) are
    not scheduled.

    Returns:
        (schedule, pred, crit): the new order of in-range node ids followed by the
        out-of-range ids of `base_order` (in their original relative order), the
        predecessor lists, and the per-node crit values.
    """
    adj, indeg, pred = build_graph(N, edges)
    w = [cycles.get(i, 0) for i in range(N)]
    crit = compute_longest_path_to_sink(N, adj, w)
    _, _, slack, _ = compute_es_ls_slack(N, adj, pred, cycles)

    # Downstream criticality: average crit of the immediate successors.
    down_gain = [0] * N
    for v in range(N):
        succ = adj[v]
        if succ:
            down_gain[v] = sum(crit[x] for x in succ) / len(succ)

    base_pos = {v: i for i, v in enumerate(base_order) if 0 <= v < N}
    no_pos = 1 << 30  # nodes absent from base_order lose the final tie-break

    indeg_now = indeg[:]
    E = [0] * N
    pipe_ready = defaultdict(int)
    # Ready node -> max finish time of its predecessors (fixed once the node is ready).
    dep_ready = {v: 0 for v in range(N) if indeg_now[v] == 0}

    def score(v):
        dr = dep_ready[v]
        pr = pipe_of.get(v)
        pr_ready = pipe_ready[pr] if pr else 0
        start = pr_ready if pr_ready > dr else dr
        dur = max(0, cycles.get(v, 0))
        # Node ids are unique, so tuple comparison never reaches `pr` (may be None).
        return (start + dur, -crit[v], slack[v], start - dr, -down_gain[v],
                base_pos.get(v, no_pos), start, v, pr, dur)

    schedule = []
    while dep_ready:
        *_, start, v, pr, dur = min(map(score, dep_ready))
        schedule.append(v)
        del dep_ready[v]
        E[v] = start + dur
        if pr:
            pipe_ready[pr] = E[v]
        for x in adj[v]:
            indeg_now[x] -= 1
            if indeg_now[x] == 0:
                dep_ready[x] = max((E[u] for u in pred[x]), default=0)

    extra = [x for x in base_order if x < 0 or x >= N]
    return schedule + extra, pred, crit

# -------- Critical-Path Adjacent Fine-Tuning (Limited Attempts) --------
def extract_critical_chain_by_E(E, pred):
    """Approximate the critical chain from finish times alone.

    Starts at the node with the largest finish time (first one on ties) and
    repeatedly steps to the predecessor with the largest finish time, stopping
    at a node without predecessors, on revisiting a node, or after len(E)
    steps. Pipeline (resource) predecessors are not considered.

    Returns:
        Node ids in source-to-sink order; [] for empty `E`.
    """
    if not E:
        return []
    limit = len(E)
    current = max(range(limit), key=E.__getitem__)
    walked = [current]
    visited = {current}
    steps = 0
    while steps < limit:
        steps += 1
        parents = pred[current]
        if not parents:
            break
        nxt = max(parents, key=E.__getitem__)
        if nxt in visited:
            break
        walked.append(nxt)
        visited.add(nxt)
        current = nxt
    return walked[::-1]

def improve_by_cp_adjacent_swaps(order, adj, indeg, pred, pipe_of, cycles,
                                 passes=2, try_limit=256):
    """
    Greedy local search along the critical chain using adjacent swaps.

    Each pass extracts the critical chain from the current finish times and
    walks it source-to-sink. For a chain node b at position j, with a at
    position j-1, it tries swapping the two, and keeps the swap only if the
    simulated makespan strictly decreases. A swap is attempted only when:
      * there is no direct edge a -> b. For adjacent positions in a topological
        order this alone keeps the order topological.
      * a and b are issued to the same (truthy) pipeline. evaluate_makespan
        couples nodes only through dependencies and per-pipe ready times, so
        swapping dependency-free neighbours on different (or no) pipelines
        cannot change the result. Trying them would only cost an O(N + E)
        re-simulation and use up part of the try budget.

    Assumes `order` is topological (as produced by q3_reschedule_fast).
    Performs at most `try_limit` evaluations per pass and at most `passes`
    passes, stopping early after a pass without improvement.

    Returns:
        The improved order as a new list (or `order` itself when it is empty).
    """
    if not order:
        return order
    N = len(indeg)
    edge_direct = {(u, v) for u in range(N) for v in adj[u]}

    best_order = order[:]
    best_T, _, E, _ = evaluate_makespan(best_order, adj, indeg, pred, pipe_of, cycles)

    for _ in range(max(0, passes)):
        chain = extract_critical_chain_by_E(E, pred)
        if not chain or try_limit <= 0:
            break

        pos = {v: i for i, v in enumerate(best_order)}
        tries = 0
        improved_round = False

        for node in chain:
            if tries >= try_limit:
                break
            j = pos.get(node)
            if not j:  # not in the order, or already first
                continue
            a, b = best_order[j - 1], best_order[j]
            if not (0 <= a < N and 0 <= b < N):
                continue
            if (a, b) in edge_direct:  # b depends directly on a
                continue
            pipe_a = pipe_of.get(a)
            if not pipe_a or pipe_a != pipe_of.get(b):  # swap cannot change the makespan
                continue

            best_order[j - 1], best_order[j] = b, a
            T2, _, E2, _ = evaluate_makespan(best_order, adj, indeg, pred, pipe_of, cycles)
            tries += 1
            if T2 < best_T:
                best_T, E = T2, E2
                pos[a], pos[b] = j, j - 1
                improved_round = True
            else:
                best_order[j - 1], best_order[j] = a, b

        if not improved_round:
            break
    return best_order

def write_problem3(p3_dir: Path, task: str, schedule, memory_map, spill_ops):
    """Write the Problem 3 outputs for `task` into `p3_dir` (created if needed).

    Files (UTF-8, one entry per line, overwritten if present):
      * ``<task>_schedule.txt`` — node ids in `schedule` order;
      * ``<task>_memory.txt``   — ``buf:offset`` sorted by buffer id;
      * ``<task>_spill.txt``    — ``buf:offset`` in `spill_ops` order.
    """
    p3_dir.mkdir(parents=True, exist_ok=True)

    def dump(suffix, lines):
        with open(p3_dir / f"{task}_{suffix}.txt", "w", encoding="utf-8") as fh:
            fh.writelines(lines)

    dump("schedule", (str(x) + "\n" for x in schedule))
    dump("memory", (f"{b}:{memory_map[b]}\n" for b in sorted(memory_map)))
    dump("spill", (f"{b}:{off}\n" for b, off in spill_ops))

# ===================== Plotting =====================
class MplCanvas(FigureCanvas):
    """Qt canvas widget wrapping one matplotlib Figure, exposed as ``self.fig``."""

    def __init__(self, width=8.4, height=5.6, dpi=100):
        # Keep the figure reachable before the Qt base class is initialized.
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.fig = figure
        super().__init__(figure)

def pipe_utilization(segments, makespan):
    """Return ``{pipe: busy_cycles / makespan}`` from per-pipe segments.

    `segments` maps a pipe to ``(start, duration, node)`` tuples; busy time is
    the sum of the durations. Every pipe gets 0.0 when `makespan` is not
    positive.
    """
    busy_by_pipe = {p: sum(d for _, d, _ in segs) for p, segs in segments.items()}
    if not makespan > 0:
        return {p: 0.0 for p in busy_by_pipe}
    return {p: busy / makespan for p, busy in busy_by_pipe.items()}

def sort_pipes_for_plot(seg_dict):
    """Return the keys of `seg_dict` sorted by name, with the catch-all "NONE" last.

    A None key sorts as the empty string.
    """
    return sorted(seg_dict, key=lambda pipe: (pipe == "NONE", pipe or ""))

def draw_gantt(ax, segments, title="Gantt Chart", highlight_nodes=None, show_all_node_labels=False):
    """Draw a per-pipeline Gantt chart on `ax`.

    Args:
        ax: matplotlib Axes to draw on.
        segments: ``{pipe: [(start, duration, node_id), ...]}`` as produced by
            evaluate_makespan; one row per pipe, ordered by sort_pipes_for_plot.
        title: axes title.
        highlight_nodes: optional collection of node ids that are always
            labelled, with a leading "*".
        show_all_node_labels: additionally label every bar wider than 1% of
            the chart's time span.

    The time span is computed before drawing, so the label threshold is the
    same for every bar. Each pipe's bars are drawn with a single
    ``broken_barh`` call (one collection per row instead of one per node).
    """
    pipes = sort_pipes_for_plot(segments)
    if not pipes:
        ax.text(0.5, 0.5, "No tasks", ha="center", va="center", transform=ax.transAxes)
        ax.set_axis_off()
        return

    highlight = highlight_nodes or ()
    max_t = max((st + du for p in pipes for st, du, _ in segments[p]), default=0)
    x_end = max_t if max_t > 0 else 1
    min_label_width = 0.01 * x_end

    yticks, yticklabels = [], []
    for y, p in enumerate(pipes):
        spans = sorted(segments[p], key=lambda seg: seg[0])
        if spans:
            ax.broken_barh([(st, du) for st, du, _ in spans], (y - 0.4, 0.8),
                           edgecolor="black", linewidth=0.5)
        for st, du, nid in spans:
            is_hl = nid in highlight
            if is_hl or (show_all_node_labels and du > min_label_width):
                ax.text(st + du / 2, y, f"{'*' if is_hl else ''}{nid}",
                        ha="center", va="center", fontsize=7)
        yticks.append(y)
        yticklabels.append(p)

    ax.set_yticks(yticks)
    ax.set_yticklabels(yticklabels)
    ax.set_ylim(-1, len(pipes))
    ax.set_xlim(0, x_end)
    ax.set_xlabel("Cycle")
    ax.set_title(title)
    ax.grid(True, axis="x", linestyle="--", linewidth=0.5, alpha=0.6)

def draw_util(ax, util_dict_p2, util_dict_p3, title="Pipeline Utilization"):
    """Draw grouped utilization bars per pipe: P2 left of each tick, P3 right.

    Pipes missing from one dict are drawn as 0.0. With no pipes at all, a
    "No pipelines" placeholder is shown and the axes are hidden.
    """
    pipes = sorted(set(util_dict_p2) | set(util_dict_p3))
    if not pipes:
        ax.text(0.5, 0.5, "No pipelines", ha="center", va="center", transform=ax.transAxes)
        ax.set_axis_off()
        return

    bar_w = 0.35
    ticks = list(range(len(pipes)))
    for shift, label, util in ((-bar_w / 2, "P2", util_dict_p2), (bar_w / 2, "P3", util_dict_p3)):
        heights = [util.get(p, 0.0) for p in pipes]
        ax.bar([t + shift for t in ticks], heights, width=bar_w, label=label)

    ax.set_xticks(ticks)
    ax.set_xticklabels(pipes, rotation=0)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Utilization")
    ax.set_title(title)
    ax.grid(True, axis="y", linestyle="--", linewidth=0.5, alpha=0.6)
    ax.legend()

def draw_time_compare(ax, p2, p3, task_name="Task"):
    """Draw a two-bar chart of total cycles: P2 baseline at x=0, P3 result at x=1."""
    for x, value, label in ((0, p2, "P2"), (1, p3, "P3")):
        ax.bar([x], [value], width=0.6, label=label)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(["P2", "P3"])
    ax.set_ylabel("Cycles")
    ax.set_title(f"{task_name} - Total Time Comparison")
    ax.grid(True, axis="y", linestyle="--", linewidth=0.5, alpha=0.6)
    ax.legend()

def draw_criticality(ax, crit_values, title="Node Criticality (Distance to Sink)"):
    """Bar-plot the positive criticality values, indexed by node id.

    Shows a placeholder and hides the axes when `crit_values` is empty or has
    no positive entry. Node ids are used as x tick labels only when fewer than
    50 bars are drawn.
    """
    if not crit_values:
        ax.text(0.5, 0.5, "No nodes", ha="center", va="center", transform=ax.transAxes)
        ax.set_axis_off()
        return
    positive = [(node, value) for node, value in enumerate(crit_values) if value > 0]
    if not positive:
        ax.text(0.5, 0.5, "No critical-path nodes", ha="center", va="center", transform=ax.transAxes)
        ax.set_axis_off()
        return
    node_ids = [node for node, _ in positive]
    heights = [value for _, value in positive]
    ax.bar(node_ids, heights)
    ax.set_xlabel("Node ID")
    ax.set_ylabel("Criticality")
    ax.set_title(title)
    ax.grid(True, axis="y", linestyle="--", linewidth=0.5, alpha=0.6)
    ax.set_xticks(node_ids if len(node_ids) < 50 else [])

# ===================== Task Solver Wrapper =====================
def solve_task_q3(root: Path, task: str):
    """Run the Problem 3 pipeline for one task and write its outputs to ``root/Problem3``.

    1. Load the task graph from ``root/CSV``. The baseline order is the Problem2
       schedule if present and non-empty, else the Problem1 schedule, else
       ``0..N-1``.
    2. Reorder the graph's nodes with q3_reschedule_fast, then refine the order
       with improve_by_cp_adjacent_swaps.
    3. Write the new schedule together with Problem2's memory/spill entries,
       unchanged.

    Baseline ids outside ``[0, N)`` are not part of the CSV graph and are not
    rescheduled. They are carried through q3_reschedule_fast and appended, in
    baseline order, after the rescheduled nodes.

    Returns:
        dict with the timing summary (p2_time, p3_time, improve, sequence
        lengths, whether P2 was used), plotting data (per-pipe segments,
        start/end times, crit, pipe_of, cycles) and both in-range orders.
    """
    csv_dir, p1_dir, p2_dir, p3_dir = ensure_dirs(root)
    meta = read_nodes_edges(csv_dir, task)
    N, edges = meta["N"], meta["edges"]
    pipe_of, cycles = meta["op_pipe"], meta["cycles"]

    p2_schedule = read_problem_schedule(p2_dir, task, "2")
    p1_schedule = read_problem_schedule(p1_dir, task, "1")
    base_schedule = p2_schedule if p2_schedule else (p1_schedule if p1_schedule else list(range(N)))

    mem_map = read_memory(p2_dir, task)
    spill_ops = read_spill(p2_dir, task)

    adj, indeg, pred = build_graph(N, edges)
    base_order = [x for x in base_schedule if 0 <= x < N]
    p2_time, S2, E2, seg2 = evaluate_makespan(base_order, adj, indeg, pred, pipe_of, cycles)

    # 1) Fast enhanced list scheduling (single pass). Pass the *unfiltered*
    #    baseline: q3_reschedule_fast carries its out-of-range ids through, and
    #    they are re-appended to the final schedule below. Passing base_order
    #    here would drop them silently.
    new_order, pred_used, crit = q3_reschedule_fast(N, edges, base_schedule, pipe_of, cycles)

    # 2) Critical-path adjacent fine-tuning (limited attempts) on in-range nodes only.
    improved_order = improve_by_cp_adjacent_swaps(
        [x for x in new_order if 0 <= x < N],
        adj, indeg, pred_used,
        pipe_of, cycles,
        passes=CP_LOCAL_PASSES,
        try_limit=CP_LOCAL_TRY_LIMIT
    )

    # Evaluate and write back
    p3_time, S3, E3, seg3 = evaluate_makespan(improved_order, adj, indeg, pred_used, pipe_of, cycles)
    final_schedule = improved_order + [x for x in new_order if (x < 0 or x >= N)]
    write_problem3(p3_dir, task, final_schedule, mem_map, spill_ops)

    return {
        "task": task,
        "p2_time": int(p2_time),
        "p3_time": int(p3_time),
        "improve": int(p2_time) - int(p3_time),
        "p3_schedule_len": len(final_schedule),
        "p2_schedule_len": len(base_schedule) if base_schedule else 0,
        "p2_used": True if p2_schedule else False,
        "p3_dir": str(p3_dir),

        "N": N,
        "pipes": sorted(list(set(p for p in pipe_of.values() if p is not None))),
        "seg_p2": seg2,
        "seg_p3": seg3,
        "S2": S2, "E2": E2,
        "S3": S3, "E3": E3,
        "pipe_of": pipe_of,
        "cycles": cycles,
        "crit": crit,
        "order_p2": base_order,
        "order_p3": [x for x in final_schedule if 0 <= x < N],
    }

def build_typora_md(results):
    md = []
    md.append("## Problem 3: Reduce Total Runtime Without Significantly Increasing Data Movement (Balanced Version)\n")
    md.append("### Methodology\n")
    md.append("- Single enhanced list scheduling: minimize finish time using criticality/slack/downstream gain/gap penalty\n")
    md.append("- Critical-path adjacent fine-tuning: limited local swaps along critical chain, very low overhead\n")
    md.append("\n### Runtime Comparison Across Six Cases (Unit: Cycles)\n")
    md.append("| Task | Baseline Uses P2 | P2 Time | P3 Time | Improvement | P2 Seq Len | P3 Seq Len |\n|---|---:|---:|---:|---:|---:|---:|\n")
    for r in results:
        md.append(f"| {r['task']} | {'Yes' if r['p2_used'] else 'No'} | {r['p2_time']} | {r['p3_time']} | {r['improve']} | {r['p2_schedule_len']} | {r['p3_schedule_len']} |\n")
    md.append("\n> Note: Outputs are in `Problem3/`, reusing Problem2's memory/spill assignments without adding new spills.\n")
    return "".join(md)

# ===================== Export Thread =====================
class PlotExporterWorker(QObject):
    """Worker that renders every per-task plot and saves it as PNG.

    `run` writes five images per result into ``output_base_dir/<task>/``.
    It is presumably moved to a QThread by the GUI; the caller is not visible
    here.

    Signals:
        progress(str): status lines, including per-plot warnings and a final
            "All plot exports completed." on success.
        error(str): emitted if the export aborts on an unexpected exception.
        finished(): emitted once after all results were processed successfully.

    Figures are plain ``matplotlib.figure.Figure`` objects. pyplot never tracks
    them, so they are freed by garbage collection, and pyplot's global figure
    state is deliberately not touched from this worker.
    """
    finished = pyqtSignal()
    error = pyqtSignal(str)
    progress = pyqtSignal(str)

    def __init__(self, results, output_base_dir):
        super().__init__()
        self.results = results  # list of dicts as returned by solve_task_q3
        self.output_base_dir = output_base_dir  # Path-like: uses .mkdir() and "/"

    def _save_figure(self, out_path, figsize, draw, what):
        """Create a figure, fill it with ``draw(ax)``, and save it to `out_path`.

        A failure emits a warning via `progress` instead of aborting, so the
        remaining plots are still exported.
        """
        try:
            fig = Figure(figsize=figsize, dpi=120)
            draw(fig.add_subplot(111))
            fig.tight_layout()
            fig.savefig(out_path)
        except Exception as plot_err:
            self.progress.emit(f"  [Warning] Failed to export {what}: {plot_err}")

    def _export_task(self, r, out_dir):
        """Export the P2/P3 Gantt, utilization, time-comparison and criticality plots of one result."""
        task = r["task"]
        crit = r["crit"]
        # Label the (up to) 10 nodes with the largest positive criticality on both Gantt charts.
        positive = [n for n, c in enumerate(crit) if c > 0]
        top_idx = set(sorted(positive, key=lambda n: crit[n], reverse=True)[:10])

        def gantt(seg_key, label):
            return lambda ax: draw_gantt(ax, r[seg_key], title=f"{task} - {label} Gantt Chart",
                                         highlight_nodes=top_idx, show_all_node_labels=False)

        def utilization(ax):
            makespan_p2 = max(r["E2"]) if r["E2"] else 1
            makespan_p3 = max(r["E3"]) if r["E3"] else 1
            draw_util(ax, pipe_utilization(r["seg_p2"], makespan_p2),
                      pipe_utilization(r["seg_p3"], makespan_p3),
                      title=f"{task} - Pipeline Utilization")

        self._save_figure(out_dir / f"{task}_P2_Gantt.png", (12, 6),
                          gantt("seg_p2", "P2"), f"{task} P2 Gantt")
        self._save_figure(out_dir / f"{task}_P3_Gantt.png", (12, 6),
                          gantt("seg_p3", "P3"), f"{task} P3 Gantt")
        self._save_figure(out_dir / f"{task}_Utilization.png", (8, 5),
                          utilization, f"{task} utilization plot")
        self._save_figure(out_dir / f"{task}_TimeCompare.png", (6, 4),
                          lambda ax: draw_time_compare(ax, r["p2_time"], r["p3_time"], task),
                          f"{task} time comparison")
        self._save_figure(out_dir / f"{task}_Criticality.png", (12, 5),
                          lambda ax: draw_criticality(ax, crit, title=f"{task} - Node Criticality (Distance to Sink)"),
                          f"{task} criticality plot")

    def run(self):
        """Export plots for every result. Emits `finished` on success, `error` otherwise."""
        try:
            self.output_base_dir.mkdir(parents=True, exist_ok=True)
            total = len(self.results)
            for i, r in enumerate(self.results, start=1):
                self.progress.emit(f"Task {r['task']}: Exporting plots ({i}/{total})...")
                task_out_dir = self.output_base_dir / r["task"]
                task_out_dir.mkdir(parents=True, exist_ok=True)
                self._export_task(r, task_out_dir)
        except Exception as e:
            self.error.emit(f"Error during plot export: {e}")
            return
        # Only report completion when the export actually completed.
        self.progress.emit("All plot exports completed.")
        self.finished.emit()

# ===================== Main Window =====================
class MplCanvas(FigureCanvas):
    """Qt canvas widget that hosts one matplotlib Figure (available as ``self.fig``).

    NOTE(review): this redefines the identically-behaving MplCanvas declared in
    the plotting section; being later in the module, this is the definition
    MainWin instantiates. One of the two can likely be removed.
    """

    def __init__(self, width=8.4, height=5.6, dpi=100):
        self.fig = Figure(dpi=dpi, figsize=(width, height))
        super().__init__(self.fig)

class MainWin(QMainWindow):
    """Main window: runs Problem 3 over all tasks and visualizes the results.

    Layout:

    - A top button bar: pick root, run, copy report, export plots.
    - A label listing the detected directories.
    - A splitter. Its left pane holds the results table and a log. Its right
      pane holds a plot selector and an embedded matplotlib canvas.

    Scheduling (``on_run``) runs synchronously on the GUI thread and pumps
    events between tasks. Plot export is handed to ``PlotExporterWorker`` on a
    separate ``QThread``.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle(APP_TITLE); self.resize(1200, 760)

        w = QWidget(); self.setCentralWidget(w)
        main = QVBoxLayout(w)

        # --- Top button bar ---
        top = QHBoxLayout()
        self.btn_pick = QPushButton("Select Root Directory (must contain CSV/, Problem1/, Problem2/)")
        self.btn_run = QPushButton("Run Problem 3 (All Cases)")
        self.btn_copy = QPushButton("Copy Report Snippet (Typora)")
        self.btn_export = QPushButton("Export All Plots (PNG)")
        top.addWidget(self.btn_pick); top.addWidget(self.btn_run)
        top.addStretch(1); top.addWidget(self.btn_copy); top.addWidget(self.btn_export)
        main.addLayout(top)

        self.lbl = QLabel("Root Directory: Not Selected"); main.addWidget(self.lbl)

        splitter = QSplitter()
        left = QWidget(); left_layout = QVBoxLayout(left)

        # --- Left pane: results table (one row per task, same order as self.results) + log ---
        self.table = QTableWidget(0, 7)
        self.table.setHorizontalHeaderLabels(["Task","Uses P2","P2 Time","P3 Time","Improve","P2 Len","P3 Len"])
        self.table.horizontalHeader().setStretchLastSection(True)
        left_layout.addWidget(self.table)

        self.log = QTextEdit(); self.log.setReadOnly(True); left_layout.addWidget(self.log)
        splitter.addWidget(left)

        # --- Right pane: plot selector + canvas ---
        right = QWidget(); right_layout = QVBoxLayout(right)
        opt_line = QHBoxLayout()
        self.cmb_which = QComboBox()
        self.cmb_which.addItems(["Gantt: P2","Gantt: P3","Utilization P2 vs P3","Time Comparison","Criticality (Approx CP)"])
        self.btn_draw = QPushButton("Visualize Selected Task")
        opt_line.addWidget(self.cmb_which); opt_line.addStretch(1); opt_line.addWidget(self.btn_draw)
        right_layout.addLayout(opt_line)

        self.canvas = MplCanvas(width=8.4, height=5.6, dpi=100)
        # PyQt6 moved size-policy values into the scoped QSizePolicy.Policy enum.
        if QT_LIB == "PyQt6":
            self.canvas.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        else:
            self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        right_layout.addWidget(self.canvas)

        splitter.addWidget(right); splitter.setStretchFactor(0, 2); splitter.setStretchFactor(1, 3)
        main.addWidget(splitter)

        # --- State ---
        self.root = Path(os.getcwd())
        self.csv_dir = self.p1_dir = self.p2_dir = self.p3_dir = None
        self.results = []

        # Background export machinery; both are None while no export is active.
        self.exporter_thread = None; self.exporter_worker = None

        self.btn_pick.clicked.connect(self.on_pick)
        self.btn_run.clicked.connect(self.on_run)
        self.btn_copy.clicked.connect(self.on_copy)
        self.btn_export.clicked.connect(self.on_export_all)
        self.btn_draw.clicked.connect(self.on_draw_selected)
        self.table.cellClicked.connect(self.on_table_clicked)

        self.auto_locate()

    def auto_locate(self):
        """Try to use the current working directory as the project root.

        On success, store the directories returned by ``ensure_dirs``. On
        failure, reset them to None and log a warning or error. Exceptions are
        never propagated.
        """
        try:
            csv_dir, p1_dir, p2_dir, p3_dir = ensure_dirs(self.root)
            self.csv_dir, self.p1_dir, self.p2_dir, self.p3_dir = csv_dir, p1_dir, p2_dir, p3_dir
            self.lbl.setText(f"Root: {self.root}\n  CSV: {csv_dir.name}\n  P1: {p1_dir.name}\n  P2: {p2_dir.name}\n  P3(Output): {p3_dir.name}")
            self.log.append(f"[Info] Auto-detected root directory: {self.root}")
        except FileNotFoundError as e:
            self.csv_dir = self.p1_dir = self.p2_dir = self.p3_dir = None
            self.lbl.setText("Root Directory: Not Selected (Current working dir invalid)")
            self.log.append(f"[Warning] Auto-detection failed: {e}\nPlease manually select a valid root directory.")
        except Exception as e:
            self.csv_dir = self.p1_dir = self.p2_dir = self.p3_dir = None
            self.lbl.setText("Root Directory: Not Selected (Error)")
            self.log.append(f"[Error] Unexpected error during auto-detection: {e}")

    def on_pick(self):
        """Let the user choose a root directory and validate it with ``ensure_dirs``.

        A valid choice replaces the root and clears previous results and the
        canvas. An invalid choice (``FileNotFoundError``) resets the
        directories to None but leaves ``self.results`` untouched.
        """
        try:
            initial_path = str(self.root) if self.root.exists() else os.getcwd()
            d = QFileDialog.getExistingDirectory(self, "Select Root Directory (must contain CSV/, Problem1/, Problem2/)", initial_path)
            if d:
                selected_root = Path(d)
                # Validate before committing so self.root only changes on success.
                csv_dir, p1_dir, p2_dir, p3_dir = ensure_dirs(selected_root)
                self.root = selected_root
                self.csv_dir, self.p1_dir, self.p2_dir, self.p3_dir = csv_dir, p1_dir, p2_dir, p3_dir
                self.lbl.setText(f"Root: {self.root}\n  CSV: {csv_dir.name}\n  P1: {p1_dir.name}\n  P2: {p2_dir.name}\n  P3(Output): {p3_dir.name}")
                self.log.append(f"[Info] Manually selected root: {self.root}")
                self.table.setRowCount(0); self.results = []
                self.canvas.fig.clf(); self.canvas.draw_idle()
            else:
                self.log.append("[Info] Directory selection canceled.")
        except FileNotFoundError as e:
            QMessageBox.critical(self, "Error", f"Invalid directory:\n{e}\nEnsure it contains all required subdirectories.")
            self.log.append(f"[Error] Directory validation failed: {e}")
            self.csv_dir = self.p1_dir = self.p2_dir = self.p3_dir = None
            self.lbl.setText("Root Directory: Invalid Selection")
        except Exception as e:
            QMessageBox.critical(self, "Error", f"Unexpected error during selection:\n{e}")
            self.log.append(f"[Error] Selection error: {e}")

    def on_run(self):
        """Run ``solve_task_q3`` for every task in ``TASKS`` and fill the results table.

        The run is synchronous. ``processEvents`` is called around each task so
        the log keeps updating. The UI is disabled during the run and
        re-enabled in ``finally``. The first failing task aborts the remaining
        ones.
        """
        if not (self.csv_dir and self.p1_dir and self.p2_dir and self.p3_dir):
            QMessageBox.warning(self, "Notice", "Please select and validate a root directory first.")
            return
        self.table.setRowCount(0); self.results = []; total_imp = 0
        self._set_ui_enabled_state(False)
        try:
            for t in TASKS:
                self.log.append(f"Processing task: {t}...")
                QApplication.processEvents()
                r = solve_task_q3(self.root, t)
                self.results.append(r); self.append_row(r)
                self.log.append(f"[Done] {t}: P2={r['p2_time']}, P3={r['p3_time']}, Improve={r['improve']}")
                QApplication.processEvents()
                total_imp += r["improve"]
            self.log.append(f"\n[Summary] Total improvement: {total_imp} cycles. Output: {self.p3_dir}")
            QMessageBox.information(self, "Complete", f"All cases finished! Total improvement: {total_imp} cycles.\nOutput to: {self.p3_dir}")
        except Exception as e:
            QMessageBox.critical(self, "Error", f"Execution failed:\n{e}")
            self.log.append(f"[Error] Runtime error: {e}")
        finally:
            self._set_ui_enabled_state(True)

    def append_row(self, r):
        """Append one task-result dict ``r`` as a new row in the results table."""
        row = self.table.rowCount(); self.table.insertRow(row)
        def setc(c, text): self.table.setItem(row, c, QTableWidgetItem(str(text)))
        setc(0, r["task"])
        setc(1, "Yes" if r["p2_used"] else "No")
        setc(2, r["p2_time"])
        setc(3, r["p3_time"])
        setc(4, r["improve"])
        setc(5, r["p2_schedule_len"])
        setc(6, r["p3_schedule_len"])

    def on_copy(self):
        """Copy a Typora-flavoured Markdown report of the current results to the clipboard."""
        if not self.results:
            QMessageBox.information(self, "Notice", "Run Problem 3 first to generate results.")
            return
        md = build_typora_md(self.results)
        cb = QApplication.clipboard(); cb.setText(md)
        QMessageBox.information(self, "Copied", "Report snippet (Typora Markdown) copied to clipboard ✅")

    def current_task_result(self):
        """Return the result dict of the selected table row, or None if nothing valid is selected.

        Relies on table rows being appended in the same order as ``self.results``.
        """
        row = self.table.currentRow()
        if row < 0 or row >= len(self.results): return None
        return self.results[row]

    def on_table_clicked(self, r, c):
        """On any cell click, switch the selector to "Gantt: P3" (index 1) and redraw."""
        self.cmb_which.setCurrentIndex(1); self.on_draw_selected()

    def on_draw_selected(self):
        """Draw the plot chosen in the combo box for the selected task onto the canvas.

        The (up to) 10 nodes with the highest non-zero criticality are passed
        as highlights to the Gantt charts.
        """
        if not self.results:
            QMessageBox.information(self, "Notice", "Run Problem 3 first."); return
        r = self.current_task_result()
        if not r:
            QMessageBox.information(self, "Notice", "Select a task row first."); return

        which = self.cmb_which.currentText()
        self.canvas.fig.clf(); ax = self.canvas.fig.add_subplot(111)

        crit = r["crit"]
        non_zero_crit_nodes = [i for i, c in enumerate(crit) if c > 0]
        K = min(10, len(non_zero_crit_nodes))
        top_idx = set(sorted(non_zero_crit_nodes, key=lambda i: crit[i], reverse=True)[:K])

        if which == "Gantt: P2":
            draw_gantt(ax, r["seg_p2"], title=f"{r['task']} - P2 Gantt Chart", highlight_nodes=top_idx, show_all_node_labels=True)
        elif which == "Gantt: P3":
            draw_gantt(ax, r["seg_p3"], title=f"{r['task']} - P3 Gantt Chart", highlight_nodes=top_idx, show_all_node_labels=True)
        elif which == "Utilization P2 vs P3":
            # Fall back to a makespan of 1 when there are no end times.
            makespan_p2 = max(r["E2"]) if r["E2"] else 1
            makespan_p3 = max(r["E3"]) if r["E3"] else 1
            util2 = pipe_utilization(r["seg_p2"], makespan_p2)
            util3 = pipe_utilization(r["seg_p3"], makespan_p3)
            draw_util(ax, util2, util3, title=f"{r['task']} - Pipeline Utilization")
        elif which == "Time Comparison":
            draw_time_compare(ax, r["p2_time"], r["p3_time"], r["task"])
        else:
            draw_criticality(ax, crit, title=f"{r['task']} - Node Criticality (Distance to Sink)")

        self.canvas.fig.tight_layout(); self.canvas.draw_idle()

    def on_export_all(self):
        """Export every plot for every task to ``<P3 dir>/Figures`` on a background thread.

        The UI is disabled while the export runs. ``_export_finished`` or
        ``_export_error`` re-enables it. If the export cannot be started at
        all, this method restores the UI itself so it is never left locked.
        """
        if not self.results:
            QMessageBox.information(self, "Notice", "Run Problem 3 first."); return
        # Results can outlive a valid root: a failed re-selection in on_pick sets
        # p3_dir to None without clearing results, and Path(None) would raise.
        if not self.p3_dir:
            QMessageBox.warning(self, "Notice", "Please select and validate a root directory first."); return
        self._set_ui_enabled_state(False)
        self.log.append("Starting full plot export... This may take a while.")
        QApplication.processEvents()

        try:
            output_dir = Path(self.p3_dir) / "Figures"
            self.exporter_thread = QThread()
            self.exporter_worker = PlotExporterWorker(self.results, output_dir)
            self.exporter_worker.moveToThread(self.exporter_thread)

            self.exporter_thread.started.connect(self.exporter_worker.run)
            self.exporter_worker.finished.connect(self._export_finished)
            self.exporter_worker.error.connect(self._export_error)
            self.exporter_worker.progress.connect(self.log.append)

            self.exporter_thread.start()
        except Exception as e:
            # The worker never got the chance to report back, so undo any
            # partial setup here; otherwise the buttons would stay disabled forever.
            self._cleanup_exporter(); self._set_ui_enabled_state(True)
            QMessageBox.critical(self, "Error", f"Plot export failed:\n{e}")
            self.log.append(f"[Error] Export error: {e}")

    def _export_finished(self):
        """Slot for the worker's ``finished`` signal: notify the user, tear down and re-enable the UI."""
        QMessageBox.information(self, "Complete", f"All plots exported!\nLocation: {Path(self.p3_dir) / 'Figures'}")
        self.log.append("[Done] Plot export completed.")
        self._cleanup_exporter(); self._set_ui_enabled_state(True)

    def _export_error(self, message):
        """Slot for the worker's ``error`` signal: report ``message``, tear down and re-enable the UI."""
        QMessageBox.critical(self, "Error", f"Plot export failed:\n{message}")
        self.log.append(f"[Error] Export error: {message}")
        self._cleanup_exporter(); self._set_ui_enabled_state(True)

    def _cleanup_exporter(self):
        """Release the export worker and stop/join its thread.

        Safe to call repeatedly; each half is skipped once it has been set to None.
        """
        if self.exporter_worker:
            # deleteLater defers destruction to an event in the thread the
            # worker was moved to, instead of destroying it from this thread.
            self.exporter_worker.deleteLater(); self.exporter_worker = None
        if self.exporter_thread:
            # Block until the worker thread has actually exited before dropping it.
            self.exporter_thread.quit(); self.exporter_thread.wait()
            self.exporter_thread.deleteLater(); self.exporter_thread = None

    def _set_ui_enabled_state(self, enabled: bool):
        """Enable or disable all action buttons together."""
        self.btn_pick.setEnabled(enabled); self.btn_run.setEnabled(enabled)
        self.btn_copy.setEnabled(enabled); self.btn_export.setEnabled(enabled)
        self.btn_draw.setEnabled(enabled)

    def closeEvent(self, e):
        """Make sure any running export thread is joined before the window closes."""
        try: self._cleanup_exporter()
        finally: super().closeEvent(e)

# ===================== Entry Point =====================
def main():
    """Launch the Problem 3 GUI and run the Qt event loop until the window closes.

    First switches matplotlib to the Qt backend that matches ``QT_LIB``. This
    is best effort: if the switch fails, the previously selected backend stays
    active. Then creates the application and main window, and exits the
    process with the event loop's return code.

    If startup fails, the error and its traceback are written to stderr and
    the process exits with status 1, so failures are not reported as success.
    """
    try:
        try:
            if QT_LIB == "PyQt5": matplotlib.use("Qt5Agg", force=True)
            else: matplotlib.use("QtAgg", force=True)
        except Exception:
            # Deliberately tolerated; the embedded canvas class is imported
            # explicitly at module level and does not depend on this switch.
            pass
        app = QApplication(sys.argv)
        win = MainWin(); win.show()
        # PyQt5 provides exec_(); PyQt6 only provides exec(). Resolve the method
        # up front so an AttributeError escaping the event loop itself cannot
        # cause exec() to be entered a second time.
        run_event_loop = getattr(app, "exec_", None) or app.exec
        sys.exit(run_event_loop())
    except Exception as e:
        import traceback
        print(f"[FATAL] Application startup failed: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)

if __name__ == "__main__":
    # Start the GUI only when the file is executed directly, not when imported.
    main()