In [None]:
import os
import json
import pandas as pd
import bisect
from typing import Dict, List, Tuple, Optional, Set

# --- Configuration ---
# Task CSVs are read from DATA_DIR (folder name means "CSV version"); Problem 1
# schedules come from PROBLEM1_DIR and every Problem 2 output goes to OUT_DIR,
# which is created here, at import time, if it does not exist.
DATA_DIR = r"./CSV版本"
PROBLEM1_DIR = os.path.join(DATA_DIR, "Attachment", "Problem1")
OUT_DIR = os.path.join(DATA_DIR, "Attachment", "Problem2")
os.makedirs(OUT_DIR, exist_ok=True)

# Capacity limit of each hardware cache, keyed by cache type name.
HARDWARE_CACHE_LIMITS = dict(
    L1=4096,
    UB=1024,
    L0A=256,
    L0B=256,
    L0C=512,
)

# --- Data Structures and Core Algorithms ---
class MemoryManager:
    """Best-fit allocator for a single cache, with LRU bookkeeping for spills.

    Free space is a list of ``(start, size)`` blocks kept in sorted order;
    ``free`` coalesces a released block with the free blocks directly before
    and after it. Every successful ``allocate`` and every ``touch`` stamps the
    buffer with a strictly increasing timestamp, and
    ``select_victim_for_spill`` returns the buffer with the oldest stamp.
    """

    def __init__(self, capacity: int, cache_type: str):
        self.capacity = capacity
        self.cache_type = cache_type
        # Free blocks as (start, size), kept sorted.
        self.free_blocks: List[Tuple[int, int]] = [(0, capacity)]
        # buf_id -> (start, size) of its allocated block.
        self.allocated_blocks: Dict[int, Tuple[int, int]] = {}
        # buf_id -> timestamp of its last allocate/touch.
        self.last_access_time: Dict[int, int] = {}
        self.timestamp = 0

    def _get_next_timestamp(self) -> int:
        """Advance and return the logical clock used for LRU ordering."""
        self.timestamp += 1
        return self.timestamp

    def allocate(self, buf_id: int, size: int) -> Optional[int]:
        """Place ``buf_id`` in the smallest free block that fits (best fit).

        Returns:
            The start offset, or None if the size exceeds the capacity or no
            single free block is large enough.
        """
        if size > self.capacity:
            return None
        best_idx = -1
        best_size = float('inf')
        for i, (block_start, block_size) in enumerate(self.free_blocks):
            if block_size >= size and block_size < best_size:
                best_size = block_size
                best_idx = i
                if best_size == size:
                    break  # exact fit cannot be beaten
        if best_idx != -1:
            block_start, block_size = self.free_blocks.pop(best_idx)
            self.allocated_blocks[buf_id] = (block_start, size)
            self.last_access_time[buf_id] = self._get_next_timestamp()
            remaining_size = block_size - size
            if remaining_size > 0:
                # The unused tail of the block stays free.
                bisect.insort(self.free_blocks, (block_start + size, remaining_size))
            return block_start
        return None

    def free(self, buf_id: int):
        """Release ``buf_id`` and coalesce its block with adjacent free blocks.

        Unknown buffer ids are ignored.
        """
        if buf_id not in self.allocated_blocks:
            return
        addr, size = self.allocated_blocks.pop(buf_id)
        self.last_access_time.pop(buf_id, None)
        idx = bisect.bisect_left(self.free_blocks, (addr, size))
        # Merge with the free block that ends exactly where this one starts.
        if idx > 0:
            prev_addr, prev_size = self.free_blocks[idx - 1]
            if prev_addr + prev_size == addr:
                self.free_blocks.pop(idx - 1)
                # The block that followed has shifted down into slot idx - 1.
                idx -= 1
                addr, size = prev_addr, prev_size + size
        # Merge with the free block that starts exactly where this one ends.
        if idx < len(self.free_blocks) and addr + size == self.free_blocks[idx][0]:
            _, next_size = self.free_blocks.pop(idx)
            size += next_size
        if size > 0:  # never keep zero-length free blocks
            bisect.insort(self.free_blocks, (addr, size))

    def touch(self, buf_id: int):
        """Mark an allocated buffer as just used (refreshes its LRU stamp)."""
        if buf_id in self.allocated_blocks:
            self.last_access_time[buf_id] = self._get_next_timestamp()

    def select_victim_for_spill(self) -> Optional[int]:
        """Return the least recently allocated/touched buffer, or None if empty."""
        if not self.allocated_blocks:
            return None
        return min(self.last_access_time.keys(), key=lambda k: self.last_access_time[k])

    def get_allocated_block_info(self, buf_id: int) -> Optional[Tuple[int, int]]:
        """Return ``(start, size)`` of an allocated buffer, or None."""
        return self.allocated_blocks.get(buf_id, None)

    def get_total_free_space(self) -> int:
        """Sum of all free block sizes (possibly fragmented)."""
        return sum(size for _, size in self.free_blocks)

    def get_largest_free_block_size(self) -> int:
        """Size of the largest contiguous free block (0 when full)."""
        if not self.free_blocks:
            return 0
        return max(size for _, size in self.free_blocks)

    def clear_all_blocks(self) -> List[int]:
        """Free every allocated buffer and return their ids in allocation order."""
        spilled_bufs = list(self.allocated_blocks.keys())
        for buf_id in spilled_bufs:
            self.free(buf_id)
        return spilled_bufs

    def get_current_allocations(self) -> Dict[int, int]:
        """Map each currently allocated buffer id to its start offset."""
        return {buf_id: addr for buf_id, (addr, _) in self.allocated_blocks.items()}

# --- SPILL Node Generation ---
def create_spill_nodes(base_id: int, buf_id: int, size: int, cache_type: str) -> Tuple[dict, dict]:
    """Build the SPILL_OUT / SPILL_IN node pair for one buffer.

    The SPILL_OUT node gets id ``base_id`` (pipe MTE3) and the SPILL_IN node
    ``base_id + 1`` (pipe MTE2); both cost ``2 * size + 150`` cycles and
    reference ``buf_id`` in their own ``Bufs`` list.
    """
    cycles = size * 2 + 150

    def make_node(node_id: int, op: str, pipe: str) -> dict:
        return {
            "Id": node_id,
            "Op": op,
            "Pipe": pipe,
            "Cycles": cycles,
            "Bufs": [buf_id],
            "BufId": buf_id,
            "Size": size,
            "Type": cache_type,
        }

    return make_node(base_id, "SPILL_OUT", "MTE3"), make_node(base_id + 1, "SPILL_IN", "MTE2")

# --- Main Scheduling and Allocation Logic ---
def process_schedule_with_memory_management(df_nodes: pd.DataFrame, schedule: List[int]) -> Tuple[List[int], Dict[int, int], List[str], int]:
    """Replay ``schedule`` against per-cache allocators, inserting SPILL nodes on overflow.

    Buffer sizes and cache types come from the ALLOC rows of ``df_nodes``: a
    missing ``Size`` counts as 0 and a missing ``Type`` as ``'L1'``. Each ALLOC
    is placed best-fit in its cache. When that fails, resident buffers of the
    same cache are spilled until it fits:
      - all of them at once if the total free space is too small;
      - otherwise the least recently allocated/touched one, one at a time.
    Once the new buffer is placed, every buffer spilled for it that fits again
    is re-allocated immediately and its SPILL_IN node is scheduled. Other nodes
    refresh the LRU stamp of the buffers in their ``BufsList``; FREE nodes
    release their buffer.

    Args:
        df_nodes: Node table with Id/Op/BufId/Size/Type/BufsList columns.
        schedule: Node ids in execution order.

    Returns:
        A tuple with four elements:
          - ``final_schedule``: the original ids interleaved with new
            SPILL_OUT/SPILL_IN ids.
          - ``allocations``: maps BufId to its most recent offset.
          - ``spill_log``: ``"BufId:offset"`` strings, holding the old offset
            for each SPILL_OUT and the new offset for each SPILL_IN.
          - ``extra_data``: the total size of spilled-out buffers.

    Raises:
        RuntimeError: if a buffer is larger than its cache, cannot be placed even
            after spilling all resident buffers, or a node fails to process.
    """
    node_index = {int(row['Id']): row for _, row in df_nodes.iterrows()}

    # BufId -> (size, cache type), collected from ALLOC rows.
    buf_info_map: Dict[int, Tuple[int, str]] = {}
    for _, row in df_nodes.iterrows():
        if row['Op'] == 'ALLOC':
            buf_id = int(row['BufId'])
            size = int(row['Size']) if not pd.isna(row['Size']) else 0
            cache_type = row['Type'] if not pd.isna(row['Type']) else 'L1'
            buf_info_map[buf_id] = (size, cache_type)

    mem_managers: Dict[str, MemoryManager] = {
        cache_type: MemoryManager(capacity, cache_type)
        for cache_type, capacity in HARDWARE_CACHE_LIMITS.items()
    }

    final_schedule: List[int] = []
    memory_allocation_history: Dict[int, int] = {}
    spill_log: List[str] = []
    total_extra_data_movement = 0

    # Synthetic SPILL node ids start right after the largest original id.
    new_node_id_counter = max(node_index.keys()) + 1
    new_nodes: Dict[int, dict] = {}
    # BufId -> SPILL_IN id paired with that buffer's most recent SPILL_OUT.
    pending_spill_in: Dict[int, int] = {}

    def emit_spill_out(victim_id: int, cache_type: str) -> None:
        """Create a SPILL_OUT/SPILL_IN pair for ``victim_id`` and schedule the SPILL_OUT now."""
        nonlocal new_node_id_counter, total_extra_data_movement
        victim_size, _ = buf_info_map[victim_id]
        spill_out_node, spill_in_node = create_spill_nodes(
            new_node_id_counter, victim_id, victim_size, cache_type
        )
        new_nodes[spill_out_node["Id"]] = spill_out_node
        new_nodes[spill_in_node["Id"]] = spill_in_node
        new_node_id_counter += 2
        pending_spill_in[victim_id] = spill_in_node["Id"]
        total_extra_data_movement += victim_size
        final_schedule.append(spill_out_node["Id"])
        spill_log.append(f"{victim_id}:{memory_allocation_history.get(victim_id, 'unknown')}")

    for nid in schedule:
        current_node = node_index[nid]
        op = current_node['Op']

        if op == 'ALLOC':
            try:
                buf_id = int(current_node['BufId'])
                info = buf_info_map.get(buf_id)
                if info is None or info[1] not in mem_managers:
                    # Unknown buffer or cache type: keep the node, skip allocation.
                    final_schedule.append(nid)
                    continue

                size, cache_type = info
                manager = mem_managers[cache_type]
                if size > manager.capacity:
                    raise RuntimeError(f"Buffer {buf_id} requested size {size} exceeds total capacity {manager.capacity} of cache {cache_type}")

                spilled_ids: Set[int] = set()
                offset = manager.allocate(buf_id, size)
                while offset is None:
                    if manager.get_total_free_space() < size:
                        # Not enough free bytes overall: spill every resident buffer.
                        for spilled_id in manager.clear_all_blocks():
                            if spilled_id in memory_allocation_history:
                                emit_spill_out(spilled_id, cache_type)
                                spilled_ids.add(spilled_id)
                        if manager.get_total_free_space() < size:
                            break
                    else:
                        # Enough bytes but fragmented: spill the LRU buffer.
                        victim_buf_id = manager.select_victim_for_spill()
                        if victim_buf_id is None:
                            break
                        emit_spill_out(victim_buf_id, cache_type)
                        manager.free(victim_buf_id)
                        spilled_ids.add(victim_buf_id)
                    # Retry after every spill step (including the clear-all path).
                    offset = manager.allocate(buf_id, size)

                if offset is None:
                    raise RuntimeError(
                        f"Buffer {buf_id} (size {size}) could not be placed in cache {cache_type} "
                        f"even after spilling all resident buffers"
                    )

                memory_allocation_history[buf_id] = offset
                final_schedule.append(nid)

                # Reload spilled buffers that fit again right away.
                for spilled_id in spilled_ids:
                    spilled_size, _ = buf_info_map[spilled_id]
                    new_offset = manager.allocate(spilled_id, spilled_size)
                    if new_offset is not None:
                        memory_allocation_history[spilled_id] = new_offset
                        final_schedule.append(pending_spill_in.pop(spilled_id))
                        spill_log.append(f"{spilled_id}:{new_offset}")
            except Exception as e:
                raise RuntimeError(f"Error processing ALLOC node {nid}: {str(e)}") from e

        elif op == 'FREE':
            try:
                buf_id = int(current_node['BufId'])
                if buf_id in buf_info_map:
                    size, cache_type = buf_info_map[buf_id]
                    if cache_type in mem_managers:
                        mem_managers[cache_type].free(buf_id)
                final_schedule.append(nid)
            except Exception as e:
                raise RuntimeError(f"Error processing FREE node {nid}: {str(e)}") from e

        else:
            try:
                # Any other node refreshes the LRU stamp of the buffers it uses.
                bufs_used = current_node.get('BufsList', [])
                for buf_id in bufs_used:
                    if buf_id in buf_info_map:
                        size, cache_type = buf_info_map[buf_id]
                        if cache_type in mem_managers:
                            mem_managers[cache_type].touch(buf_id)
                final_schedule.append(nid)
            except Exception as e:
                raise RuntimeError(f"Error processing node {nid}: {str(e)}") from e

    # Offsets of still-resident buffers override the recorded history.
    current_allocations = {}
    for manager in mem_managers.values():
        current_allocations.update(manager.get_current_allocations())
    combined_allocations = {**memory_allocation_history, **current_allocations}
    return final_schedule, combined_allocations, spill_log, total_extra_data_movement

# --- Helper Functions ---
def load_problem1_schedule(task_name: str) -> List[int]:
    """Read the Problem 1 schedule (one node id per non-blank line) for ``task_name``.

    Raises:
        FileNotFoundError: if ``<PROBLEM1_DIR>/<task>_schedule.txt`` is missing.
    """
    schedule_path = os.path.join(PROBLEM1_DIR, f"{task_name}_schedule.txt")
    if not os.path.exists(schedule_path):
        raise FileNotFoundError(f"Problem1 schedule file not found: {schedule_path}")
    with open(schedule_path, 'r', encoding='utf-8') as fh:
        tokens = (raw.strip() for raw in fh)
        return [int(tok) for tok in tokens if tok]

def discover_tasks(data_dir: str) -> List[Tuple[str, str, str]]:
    """List tasks in ``data_dir`` that have both a ``_Nodes.csv`` and an ``_Edges.csv`` file.

    Returns:
        ``(task, nodes_filename, edges_filename)`` tuples sorted by task name.
    """
    entries = os.listdir(data_dir)

    def bases_with(suffix: str) -> set:
        return {name[:-len(suffix)] for name in entries if name.endswith(suffix)}

    complete = bases_with("_Nodes.csv").intersection(bases_with("_Edges.csv"))
    return [(base, f"{base}_Nodes.csv", f"{base}_Edges.csv") for base in sorted(complete)]

def parse_bufs(x):
    """Extract the buffer ids referenced by one ``Bufs`` cell.

    Strings such as ``"[1, 2, 3]"`` or ``"1,2"`` yield one id per maximal run
    of digit characters. A numeric cell with an integral value (pandas gives
    e.g. ``3.0`` when a Bufs column holds only single ids) is a single id.
    Missing values (None/NaN) yield an empty list.
    """
    if pd.isna(x):
        return []
    # Numeric cells: stringifying 3.0 would wrongly produce [3, 0].
    if pd.api.types.is_integer(x):
        return [int(x)]
    if pd.api.types.is_float(x) and float(x).is_integer():
        return [int(x)]
    s = str(x)
    nums, cur = [], ''
    for ch in s:
        if ch.isdigit():
            cur += ch
        elif cur:
            nums.append(int(cur))
            cur = ''
    if cur:
        nums.append(int(cur))
    return nums

def load_graph(nodes_csv: str, edges_csv: str):
    """Load one task's node table and build successor/predecessor adjacency.

    Any of Id/Op/BufId/Size/Type/Cycles/Pipe/Bufs missing from the nodes CSV is
    added as an empty column. BufId/Size/Cycles are coerced to numbers (invalid
    entries become NaN), and a ``BufsList`` column holds the ids parsed from
    ``Bufs`` by ``parse_bufs``.

    Args:
        nodes_csv: Path of the ``*_Nodes.csv`` file.
        edges_csv: Path of the ``*_Edges.csv`` file; its StartNodeId/EndNodeId
            columns are matched case-insensitively.

    Returns:
        ``(df_nodes, succ, pred)``. ``succ`` and ``pred`` map every node id to
        the list of its successor or predecessor ids, in edge-file order. Edges
        that reference unknown node ids are ignored.

    Raises:
        ValueError: if the edge columns are missing, or if a node id or edge
            endpoint is empty or non-numeric.
    """
    df_nodes = pd.read_csv(nodes_csv)
    need_cols = ['Id', 'Op', 'BufId', 'Size', 'Type', 'Cycles', 'Pipe', 'Bufs']
    for c in need_cols:
        if c not in df_nodes.columns:
            df_nodes[c] = None

    def to_int(series: pd.Series, label: str) -> pd.Series:
        # Fail with a message naming the column and offending values instead of
        # pandas' generic "cannot convert non-finite values to integer" error.
        numeric = pd.to_numeric(series, errors='coerce')
        invalid = numeric.isna()
        if invalid.any():
            examples = series[invalid].head(5).tolist()
            raise ValueError(f"{label} has empty or non-numeric values, e.g. {examples}")
        return numeric.astype(int)

    df_nodes['Id'] = to_int(df_nodes['Id'], f"{nodes_csv} column Id")
    for c in ['BufId', 'Size', 'Cycles']:
        df_nodes[c] = pd.to_numeric(df_nodes[c], errors='coerce')
    df_nodes['BufsList'] = df_nodes['Bufs'].apply(parse_bufs)

    df_edges = pd.read_csv(edges_csv)
    cols = {c.lower(): c for c in df_edges.columns}
    if 'startnodeid' not in cols or 'endnodeid' not in cols:
        raise ValueError("Edges CSV must contain columns: StartNodeId, EndNodeId")
    s_col, e_col = cols['startnodeid'], cols['endnodeid']
    df_edges[s_col] = to_int(df_edges[s_col], f"{edges_csv} column {s_col}")
    df_edges[e_col] = to_int(df_edges[e_col], f"{edges_csv} column {e_col}")

    node_ids = set(df_nodes['Id'].tolist())
    succ = {nid: [] for nid in node_ids}
    pred = {nid: [] for nid in node_ids}
    # zip over plain lists avoids building one Series per row as iterrows() does.
    for u, v in zip(df_edges[s_col].tolist(), df_edges[e_col].tolist()):
        if u in node_ids and v in node_ids:
            succ[u].append(v)
            pred[v].append(u)
    return df_nodes, succ, pred

# --- Result Output: CSV tables, with 1-based step numbers in the schedule ---
def save_problem2_results(task_name: str, final_schedule: List[int], memory_allocation: Dict[int, int], spill_log: List[str], metrics: dict):
    """Write one task's Problem 2 results to OUT_DIR and print a short summary.

    All files are UTF-8. CSVs are written without the DataFrame index, and
    headers are kept even when a table is empty.
      1. ``<task>_schedule.csv``: Step (1-based position), NodeId.
      2. ``<task>_memory.csv``: BufId, Offset, sorted by BufId.
      3. ``<task>_spill.csv``: BufId, NewOffset. There is one row per
         ``spill_log`` entry of the form ``"<digits>:<digits>"``; other entries
         are skipped.
      4. ``<task>_metrics.json``: ``metrics`` as indented JSON.
    """
    # 1. Schedule with step numbering
    sched_path = os.path.join(OUT_DIR, f"{task_name}_schedule.csv")
    schedule_table = pd.DataFrame({
        'Step': list(range(1, len(final_schedule) + 1)),
        'NodeId': final_schedule,
    })
    schedule_table.to_csv(sched_path, index=False, encoding='utf-8')

    # 2. Memory allocations
    mem_path = os.path.join(OUT_DIR, f"{task_name}_memory.csv")
    if not memory_allocation:
        memory_table = pd.DataFrame(columns=['BufId', 'Offset'])
    else:
        memory_table = pd.DataFrame({
            'BufId': list(memory_allocation.keys()),
            'Offset': list(memory_allocation.values()),
        }).sort_values('BufId')
    memory_table.to_csv(mem_path, index=False, encoding='utf-8')

    # 3. SPILL records: keep only well-formed "BufId:Offset" entries
    spill_data = []
    for record in spill_log:
        buf_text, colon, offset_text = record.partition(':')
        if colon and buf_text.isdigit() and offset_text.isdigit():
            spill_data.append({'BufId': int(buf_text), 'NewOffset': int(offset_text)})
    spill_table = pd.DataFrame(spill_data) if spill_data else pd.DataFrame(columns=['BufId', 'NewOffset'])
    spill_path = os.path.join(OUT_DIR, f"{task_name}_spill.csv")
    spill_table.to_csv(spill_path, index=False, encoding='utf-8')

    # 4. Metrics as JSON
    metrics_path = os.path.join(OUT_DIR, f"{task_name}_metrics.json")
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    print(f"[File Saved] Task {task_name} output completed:")
    print(f"  - Schedule: {len(final_schedule)} records → {sched_path}")
    print(f"  - Memory allocations: {len(memory_allocation)} records → {mem_path}")
    print(f"  - SPILL operations: {len(spill_data)} records → {spill_path}")

# --- Task Execution Function ---
def run_problem2_for_task(task_name: str, nodes_csv: str, edges_csv: str):
    """Run the Problem 2 pipeline for one task and return its metrics dict.

    The task graph is loaded from DATA_DIR together with its Problem 1
    schedule. SPILL operations are inserted by
    process_schedule_with_memory_management, results are written via
    save_problem2_results, and a one-line summary is printed.
    """
    print(f"\n[Problem2] Starting task: {task_name}")

    df_nodes = load_graph(
        os.path.join(DATA_DIR, nodes_csv),
        os.path.join(DATA_DIR, edges_csv),
    )[0]
    final_schedule, memory_allocation, spill_log, extra_data = process_schedule_with_memory_management(
        df_nodes, load_problem1_schedule(task_name)
    )

    metrics = {
        "task": task_name,
        "original_nodes_count": len(df_nodes),
        "final_nodes_count": len(final_schedule),
        # NOTE(review): halving assumes spill_log holds one SPILL_OUT and one
        # SPILL_IN entry per spill; buffers never reloaded break that pairing.
        "spill_operations_count": len(spill_log) // 2,
        "memory_allocations_count": len(memory_allocation),
        "total_extra_data_movement": extra_data,
        "new_makespan_cycles": "Not Computed",
    }

    save_problem2_results(task_name, final_schedule, memory_allocation, spill_log, metrics)

    print(f"[Problem2] Completed task: {task_name} | SPILL count: {metrics['spill_operations_count']} | "
          f"Memory records: {metrics['memory_allocations_count']} | Extra data moved: {metrics['total_extra_data_movement']}")
    return metrics

# --- Main Entry Point ---
if __name__ == "__main__":
    # Process every task found in DATA_DIR; a failing task is reported and skipped.
    candidates = discover_tasks(DATA_DIR)
    all_metrics = []
    for task_name, nodes_fn, edges_fn in candidates:
        try:
            metrics = run_problem2_for_task(task_name, nodes_fn, edges_fn)
        except Exception as e:
            print(f"Error processing task {task_name}: {str(e)}")
        else:
            all_metrics.append(metrics)

    summary_path = os.path.join(OUT_DIR, "summary_metrics.json")
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(all_metrics, f, indent=2, ensure_ascii=False)

    print("\n[Problem2] All tasks processed. Summary saved to summary_metrics.json")
    for m in all_metrics:
        print(f"- {m['task']}: SPILL={m['spill_operations_count']}, MemoryRecords={m['memory_allocations_count']}, ExtraData={m['total_extra_data_movement']}")


[Problem2] 开始处理任务: Conv_Case0
[文件保存] 任务 Conv_Case0 输出完成：
  - 调度序列：3039 条记录 → ./CSV版本\Attachment\Problem2\Conv_Case0_schedule.csv
  - 内存分配：817 条记录 → ./CSV版本\Attachment\Problem2\Conv_Case0_memory.csv
  - SPILL操作：473 条记录 → ./CSV版本\Attachment\Problem2\Conv_Case0_spill.csv
[Problem2] 完成任务: Conv_Case0 | SPILL次数: 236 | 内存记录数: 817 | 额外数据搬运: 57146

[Problem2] 开始处理任务: Conv_Case1
[文件保存] 任务 Conv_Case1 输出完成：
  - 调度序列：47108 条记录 → ./CSV版本\Attachment\Problem2\Conv_Case1_schedule.csv
  - 内存分配：11386 条记录 → ./CSV版本\Attachment\Problem2\Conv_Case1_memory.csv
  - SPILL操作：11649 条记录 → ./CSV版本\Attachment\Problem2\Conv_Case1_spill.csv
[Problem2] 完成任务: Conv_Case1 | SPILL次数: 5824 | 内存记录数: 11386 | 额外数据搬运: 451707

[Problem2] 开始处理任务: FlashAttention_Case0
[文件保存] 任务 FlashAttention_Case0 输出完成：
  - 调度序列：2105 条记录 → ./CSV版本\Attachment\Problem2\FlashAttention_Case0_schedule.csv
  - 内存分配：529 条记录 → ./CSV版本\Attachment\Problem2\FlashAttention_Case0_memory.csv
  - SPILL操作：432 条记录 → ./CSV版本\Attachment\Problem2\FlashAttention_Cas

In [None]:
import os
import sys
import csv
import re
from pathlib import Path
from collections import defaultdict

# Prefer PyQt5 with matplotlib's Qt5 canvas. If any part of that import path
# raises, fall back to PyQt6 with matplotlib's generic Qt canvas.
# QT_LIB records which binding was loaded.
try:
    from PyQt5 import QtWidgets, QtCore
    from PyQt5.QtWidgets import QApplication, QMessageBox, QTableWidgetItem, QDialog, QFileDialog
    QT_LIB = "PyQt5"
    from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
except Exception:
    # NOTE(review): this catches any Exception, not only ImportError, so a
    # non-import failure on the PyQt5 path also triggers the PyQt6 fallback.
    from PyQt6 import QtWidgets, QtCore
    from PyQt6.QtWidgets import QApplication, QMessageBox, QTableWidgetItem, QDialog, QFileDialog
    QT_LIB = "PyQt6"
    from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas

import matplotlib
# Select matplotlib's Qt Agg backend.
matplotlib.use("QtAgg")
import matplotlib.pyplot as plt

# Font fallback list with CJK-capable fonts first, so Chinese text in figures can render.
plt.rcParams['font.sans-serif'] = [
    'SimHei', 'Microsoft YaHei', 'HanaMinA', 'Source Han Sans CN',
    'Noto Sans CJK SC', 'WenQuanYi Zen Hei', 'PingFang SC', 'sans-serif'
]
# Draw minus signs with an ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False

# Window title, including the Qt binding actually loaded.
APP_TITLE = f"Problem 2: Cache Allocation and SPILL Scheduler ({QT_LIB})"

# Cache capacity per cache type (unit: bytes)
CAPACITY = {"L1":4096, "UB":1024, "L0A":256, "L0B":256, "L0C":512}

def parse_nodes_edges(csv_dir: Path, task_name: str):
    """Read ``<task>_Nodes.csv`` and ``<task>_Edges.csv`` from ``csv_dir``.

    Returns:
        A tuple ``(nodes, edges, alloc_of_buf, free_of_buf, bufs_used_by_node,
        buf_attr, copyin_uses_buf, N)``:
          - nodes: node id -> CSV row dict, with "Op" stripped and upper-cased.
          - edges: list of (StartNodeId, EndNodeId) pairs in file order.
          - alloc_of_buf / free_of_buf: buffer id -> id of its ALLOC / FREE node.
          - bufs_used_by_node: defaultdict(list) mapping node id -> buffer ids
            listed in "Bufs", for nodes that are neither ALLOC nor FREE.
          - buf_attr: buffer id -> (Type upper-cased, with the typos LOA/LOB/LOC
            corrected to L0A/L0B/L0C, Size).
          - copyin_uses_buf: buffer ids referenced by COPY_IN nodes.
          - N: largest node id + 1 (0 when there are no nodes).

    Raises:
        FileNotFoundError: if either CSV file is missing.
    """
    nodes_path = csv_dir / f"{task_name}_Nodes.csv"
    edges_path = csv_dir / f"{task_name}_Edges.csv"
    if not (nodes_path.exists() and edges_path.exists()):
        raise FileNotFoundError(f"Missing CSV files: {task_name}")

    nodes = {}
    alloc_of_buf, free_of_buf = {}, {}
    bufs_used_by_node = defaultdict(list)
    buf_attr = {}
    copyin_uses_buf = set()
    # Characters stripped from a Bufs cell before splitting on commas.
    noise = re.compile(r'[\"\[\]\s]')
    type_fixes = (("LOA", "L0A"), ("LOB", "L0B"), ("LOC", "L0C"))

    with open(nodes_path, newline='', encoding="utf-8") as fh:
        for record in csv.DictReader(fh):
            node_id = int(record["Id"])
            kind = (record["Op"] or "").strip().upper()
            record["Op"] = kind
            nodes[node_id] = record

            if kind == "ALLOC":
                buf = int(record["BufId"])
                size = int(record["Size"])
                typ = (record["Type"] or "").strip().upper()
                for wrong, right in type_fixes:
                    typ = typ.replace(wrong, right)
                alloc_of_buf[buf] = node_id
                buf_attr[buf] = (typ, size)
            elif kind == "FREE":
                free_of_buf[int(record["BufId"])] = node_id
            else:
                cleaned = noise.sub('', (record.get("Bufs") or "").strip())
                for token in cleaned.split(","):
                    if not token:
                        continue
                    b = int(token)
                    bufs_used_by_node[node_id].append(b)
                    if kind == "COPY_IN":
                        copyin_uses_buf.add(b)

    with open(edges_path, newline='', encoding="utf-8") as fh:
        edges = [(int(r["StartNodeId"]), int(r["EndNodeId"])) for r in csv.DictReader(fh)]

    N = (max(nodes) + 1) if nodes else 0
    return nodes, edges, alloc_of_buf, free_of_buf, bufs_used_by_node, buf_attr, copyin_uses_buf, N

def read_schedule(problem1_dir: Path, task_name: str):
    """Return the Problem 1 schedule of ``task_name`` as a list of node ids.

    ``<problem1_dir>/<task>_schedule.txt`` holds one node id per line; blank
    lines are skipped.

    Raises:
        FileNotFoundError: if the schedule file is missing.
    """
    path = problem1_dir / f"{task_name}_schedule.txt"
    if not path.exists():
        raise FileNotFoundError(f"Missing Problem1 schedule sequence: {path}")
    with open(path, "r", encoding="utf-8") as fh:
        return [int(tok) for tok in map(str.strip, fh) if tok]

class FreeList:
    """Free-interval list with best-fit allocation.

    ``intervals`` holds free ``(start, length)`` pairs and starts as a single
    interval covering ``[0, capacity)``. Each ``free`` re-sorts the list and
    fuses intervals that touch.
    """

    def __init__(self, capacity: int):
        """Start with one free interval spanning the whole capacity."""
        self.intervals = [(0, capacity)]

    def _merge(self):
        """Sort intervals by start and fuse each one that begins where the previous ends."""
        self.intervals.sort()
        coalesced = []
        for start, length in self.intervals:
            if coalesced and coalesced[-1][0] + coalesced[-1][1] == start:
                prev_start, prev_len = coalesced[-1]
                coalesced[-1] = (prev_start, prev_len + length)
            else:
                coalesced.append((start, length))
        self.intervals = coalesced

    def alloc_best_fit(self, size: int):
        """Carve ``size`` units from the smallest interval that fits.

        Ties go to the earliest interval in the list. The chosen interval is
        removed on an exact fit and shrunk from the front otherwise.

        Returns:
            The start offset, or None if no interval is large enough.
        """
        fitting = [idx for idx, (_, length) in enumerate(self.intervals) if length >= size]
        if not fitting:
            return None
        chosen = min(fitting, key=lambda idx: self.intervals[idx][1])
        start, length = self.intervals[chosen]
        if length == size:
            del self.intervals[chosen]
        else:
            self.intervals[chosen] = (start + size, length - size)
        return start

    def free(self, start: int, size: int):
        """Return ``[start, start + size)`` to the free list and coalesce."""
        self.intervals.append((start, size))
        self._merge()

class AllocState:
    """Residency record of one allocated buffer (slotted, no per-instance dict)."""
    __slots__ = ("buf", "off", "size", "next_use_pos")

    def __init__(self, buf, off, size, next_use_pos):
        # buf: buffer id; off: start offset in its cache; size: buffer size;
        # next_use_pos: schedule position of the buffer's next use.
        self.buf, self.off, self.size, self.next_use_pos = buf, off, size, next_use_pos

def solve_task(csv_dir:Path, problem1_dir:Path, problem2_dir:Path, task_name:str):
    """
    Main logic to solve Problem 2: cache allocation with SPILL scheduling
    Args:
        csv_dir: Directory with CSV input files
        problem1_dir: Directory with Problem1 schedule files
        problem2_dir: Directory for Problem2 output files
        task_name: Name of target task
    Returns:
        Dictionary containing task execution metrics
    """
    (nodes, edges, alloc_of_buf, free_of_buf, bufs_used_by_node,
     buf_attr, copyin_uses_buf, N0) = parse_nodes_edges(csv_dir, task_name)
    S = read_schedule(problem1_dir, task_name)
    # Map node id to its position in schedule
    pos = {nid:i for i,nid in enumerate(S)}

    # Track all usage positions of each buffer
    uses = defaultdict(list)
    for nid in S:
        for b in bufs_used_by_node.get(nid,[]):
            uses[b].append(pos[nid])

    # Initialize free lists for different cache types
    freelists = {t:FreeList(CAPACITY[t]) for t in CAPACITY.keys()}
    # Track live (allocated) buffers per cache type
    live = {t:{} for t in CAPACITY.keys()}
    extra_movement = 0  # Total extra data movement caused by SPILL operations
    spill_ops = []      # Record of SPILL operations (buffer id, offset)
    next_new_id = N0    # Next available id for new SPILL nodes

    # Track first allocation offset of each buffer
    first_offset = {}

    # New nodes to insert before specific positions in schedule
    inserts_before = defaultdict(list)

    # Cursor to track next usage position for each buffer
    use_cursor = {b:0 for b in buf_attr.keys()}
    def next_use_pos_of(b, curpos):
        """
        Find next usage position of buffer after current position
        Args:
            b: Buffer id
            curpos: Current position in schedule
        Returns:
            Next usage position (infinite if no future use)
        """
        arr = uses.get(b,[])
        k = use_cursor[b]
        # Move cursor past positions <= current position
        while k < len(arr) and arr[k] <= curpos:
            k += 1
        use_cursor[b] = k
        return arr[k] if k < len(arr) else float("inf")

    def evict_one(type_name, curpos):
        """
        Evict one buffer from specified cache type based on eviction strategy
        Args:
            type_name: Cache type (L1/UB/L0A/L0B/L0C)
            curpos: Current position in schedule
        Returns:
            (evicted buffer id, buffer size)
        """
        nonlocal extra_movement, next_new_id
        pool = list(live[type_name].values())
        if not pool:
            return (None, 0)
        
        if type_name in ("L0A", "L0B", "L0C"):
            # Eviction strategy for L0A/B/C: 
            # Prioritize evicting buffers with no future use or farthest next use
            nofuture = [x for x in pool if x.next_use_pos == float("inf")]
            if nofuture:
                victim = nofuture[0]  # Any buffer with no future use
            else:
                victim = max(pool, key=lambda x: x.next_use_pos)  # Farthest next use
        else:
            # Eviction strategy for L1/UB: cost-weighted Belady-like
            # Priority P(b*) = (distance to next use) * (spill_cost / size)
            priorities = []
            for item in pool:
                # Calculate spill cost per unit size
                cost_per_unit = 0
                spill_out_c = 0 if item.buf in copyin_uses_buf else item.size
                spill_in_c = item.size
                cost_per_unit = (spill_out_c + spill_in_c) / item.size if item.size > 0 else 0

                # Time distance to next use
                time_dist = (item.next_use_pos - curpos) if item.next_use_pos != float("inf") else float("inf")
                
                # Adjustment: Buffers with infinite next_use_pos should be evicted first
                # Multiply by -1 to make max() select the buffer most needing eviction
                priority_score = -time_dist * cost_per_unit
                if time_dist == float("inf"):
                    priority_score = float("-inf")  # Lowest score (highest eviction priority)

                priorities.append((priority_score, item))
            
            victim = max(priorities, key=lambda x: x[0])[1]

        b = victim.buf
        spill_out_id = next_new_id
        next_new_id += 1
        inserts_before[curpos].append(spill_out_id)
        
        # Count extra data movement for spill operations
        if b not in copyin_uses_buf:
            extra_movement += victim.size  # SpillOut cost
            extra_movement += victim.size  # SpillIn cost
        
        # Free the evicted buffer's memory
        freelists[type_name].free(victim.off, victim.size)
        del live[type_name][b]
        return (b, victim.size)

    def ensure_alloc(b, curpos):
        """
        Make buffer ``b`` resident, evicting other buffers of its type as needed.

        Args:
            b: Buffer id to allocate (its type and size come from ``buf_attr``)
            curpos: Current position in schedule
        Returns:
            The memory offset returned by ``freelists[typ].alloc_best_fit``
        Raises:
            RuntimeError: If allocation still fails and ``evict_one`` reports
                that no buffer is available for eviction
        """
        typ, size = buf_attr[b]

        # Special handling for L0A/B/C: evict all live buffers first.
        # Stop as soon as evict_one reports nothing evictable (victim None):
        # a resident buffer that evict_one will not select would otherwise
        # keep `live[typ]` non-empty and spin this loop forever. If space is
        # still insufficient, the allocation loop below raises a clear error.
        if typ in ("L0A", "L0B", "L0C"):
            while live[typ]:
                victim_buf, _ = evict_one(typ, curpos)
                if victim_buf is None:
                    break

        # Try allocation with eviction loop
        while True:
            off = freelists[typ].alloc_best_fit(size)
            if off is not None:
                nu = next_use_pos_of(b, curpos)
                live[typ][b] = AllocState(b, off, size, nu)
                # Only the first placement of each buffer is recorded.
                if b not in first_offset:
                    first_offset[b] = off
                return off

            # Evict one buffer and retry if allocation fails
            victim_buf, _ = evict_one(typ, curpos)
            if victim_buf is None:
                live_bufs = list(live[typ].keys())
                raise RuntimeError(
                    f"Failed to allocate {typ} at position {curpos} (Buffer ID: {b}, Size: {size}, Capacity: {CAPACITY[typ]}). "
                    f"Currently resident Buffer IDs: {live_bufs}. No buffers available for eviction."
                )

    # Preprocess buffers used by each node
    bufs_of_node = defaultdict(list)
    for nid in nodes:
        for b in bufs_used_by_node.get(nid,[]):
            bufs_of_node[nid].append(b)

    new_nodes_added = 0
    # Process each node in schedule sequence
    for i, nid in enumerate(S):
        row = nodes[nid]
        op = (row["Op"] or "").strip().upper()

        # Handle buffers used by current node
        for b in bufs_of_node.get(nid, []):
            typ_size = buf_attr.get(b)
            if not typ_size:
                continue
            typ, size = typ_size
            # Allocate buffer if not already live (add SPILL_IN node)
            if b not in live[typ]:
                spill_in_id = next_new_id
                next_new_id+=1
                inserts_before[i].append(spill_in_id)
                new_nodes_added += 1
                
                extra_movement += size  # Count SPILL_IN data movement
                
                off = ensure_alloc(b, i)
                spill_ops.append((b, off))

        # Handle ALLOC/FREE operations
        if op == "ALLOC":
            b = int(row["BufId"])
            ensure_alloc(b, i)
        elif op == "FREE":
            b = int(row["BufId"])
            typ, size = buf_attr[b]
            if b in live[typ]:
                # Free buffer memory
                freelists[typ].free(live[typ][b].off, size)
                del live[typ][b]
        else:
            # Update next use position for live buffers
            for b in bufs_of_node.get(nid, []):
                typ, size = buf_attr[b]
                if b in live[typ]:
                    live[typ][b].next_use_pos = next_use_pos_of(b, i)

    # Generate new schedule with inserted SPILL nodes
    new_schedule = []
    for i, nid in enumerate(S):
        if i in inserts_before:
            new_schedule.extend(inserts_before[i])
        new_schedule.append(nid)

    # Create output directory and write results
    problem2_dir.mkdir(parents=True, exist_ok=True)
    # Write new schedule
    with open(problem2_dir/f"{task_name}_schedule.txt", "w", encoding="utf-8") as f:
        for x in new_schedule:
            f.write(str(x)+"\n")
    # Write buffer first allocation offsets
    with open(problem2_dir/f"{task_name}_memory.txt", "w", encoding="utf-8") as f:
        for b in sorted(first_offset.keys()):
            f.write(f"{b}:{first_offset[b]}\n")
    # Write SPILL operations
    with open(problem2_dir/f"{task_name}_spill.txt", "w", encoding="utf-8") as f:
        for b,off in spill_ops:
            f.write(f"{b}:{off}\n")

    # Return task metrics
    return {
        "task": task_name,
        "extra_movement": int(extra_movement),
        "spill_count": len(spill_ops),
        "new_nodes_added": new_nodes_added,
        "schedule_len": len(new_schedule)
    }

class MplCanvas(FigureCanvas):
    """Qt-embeddable matplotlib canvas holding a single-Axes figure.

    The figure is available as ``self.fig`` and its Axes as ``self.ax``.
    """
    def __init__(self, parent=None, width=6, height=3.6, dpi=110):
        figure, axes = plt.subplots(figsize=(width, height), dpi=dpi)
        super().__init__(figure)
        self.setParent(parent)
        self.fig = figure
        self.ax = axes

class ChartDialog(QDialog):
    """Dialog that renders one matplotlib chart above a Close button.

    ``plot_fn`` is called once with the canvas Axes and draws the chart;
    the figure is then laid out and rendered.
    """
    def __init__(self, parent, title, plot_fn):
        super().__init__(parent)
        self.setWindowTitle(title)
        self.resize(820, 520)
        layout = QtWidgets.QVBoxLayout(self)

        self.canvas = MplCanvas(self, width=6.8, height=3.8, dpi=110)
        layout.addWidget(self.canvas)

        # Let the caller draw, then tighten margins and paint.
        plot_fn(self.canvas.ax)
        self.canvas.fig.tight_layout()
        self.canvas.draw()

        # Scoped enum access under PyQt6; the short form for other bindings.
        if QT_LIB == "PyQt6":
            close_flag = QtWidgets.QDialogButtonBox.StandardButton.Close
        else:
            close_flag = QtWidgets.QDialogButtonBox.Close
        buttons = QtWidgets.QDialogButtonBox(close_flag)
        buttons.rejected.connect(self.reject)
        buttons.accepted.connect(self.accept)
        layout.addWidget(buttons)

class MainWin(QtWidgets.QMainWindow):
    """Main application window.

    Lets the user choose the CSV-version and Problem1 directories, runs
    ``solve_task`` for every test case, lists per-task metrics in a table,
    and offers clipboard export plus two summary charts.
    """
    def __init__(self):
        super().__init__()
        self.setWindowTitle(APP_TITLE)
        self.resize(1000, 680)
        w = QtWidgets.QWidget()
        self.setCentralWidget(w)
        main_layout = QtWidgets.QVBoxLayout(w)

        # Button row
        btn_row = QtWidgets.QHBoxLayout()
        self.btn_select_csv = QtWidgets.QPushButton("Select CSV Version Directory")
        self.btn_select_prob1 = QtWidgets.QPushButton("Select Problem1 Directory")
        self.btn_run = QtWidgets.QPushButton("Run Problem 2 (All Test Cases)")
        self.btn_copy = QtWidgets.QPushButton("Copy Table to Clipboard")
        self.btn_chart1 = QtWidgets.QPushButton("View Chart 1 (Extra Data Movement per Task)")
        self.btn_chart2 = QtWidgets.QPushButton("View Chart 2 (SPILL Count vs New Schedule Length)")
        btn_row.addWidget(self.btn_select_csv)
        btn_row.addWidget(self.btn_select_prob1)
        btn_row.addWidget(self.btn_run)
        btn_row.addStretch(1)
        btn_row.addWidget(self.btn_copy)
        btn_row.addWidget(self.btn_chart1)
        btn_row.addWidget(self.btn_chart2)
        main_layout.addLayout(btn_row)

        # Path display label
        self.lbl_paths = QtWidgets.QLabel("Please select CSV Version and Problem1 directories.")
        self.lbl_paths.setWordWrap(True)
        main_layout.addWidget(self.lbl_paths)

        # Results table: one row per task, columns match the solve_task metrics dict
        self.table = QtWidgets.QTableWidget(0,5)
        self.table.setHorizontalHeaderLabels(["Task","Extra Data Movement","SPILL Count","New Nodes Added","New Schedule Length"])
        self.table.horizontalHeader().setStretchLastSection(True)
        main_layout.addWidget(self.table, stretch=1)

        # Status bar
        self.status = QtWidgets.QLabel("Ready.")
        main_layout.addWidget(self.status)

        # Connect button signals
        self.btn_select_csv.clicked.connect(self.on_select_csv_dir)
        self.btn_select_prob1.clicked.connect(self.on_select_prob1_dir)
        self.btn_run.clicked.connect(self.on_run_all)
        self.btn_copy.clicked.connect(self.on_copy)
        self.btn_chart1.clicked.connect(self.on_show_chart1)
        self.btn_chart2.clicked.connect(self.on_show_chart2)

        # Initialize directory variables
        self.csv_dir = None
        self.prob1_dir = None
        self.prob2_dir = None
        # List of test tasks
        self.tasks = [
            "Conv_Case0","Conv_Case1",
            "FlashAttention_Case0","FlashAttention_Case1",
            "Matmul_Case0","Matmul_Case1",
        ]
        self.results_cache = []  # Metrics dicts from the most recent run, in task order

    def update_path_display(self):
        """Refresh the path label and status line from the current directory selections."""
        csv_path_str = str(self.csv_dir) if self.csv_dir else "Not selected"
        prob1_path_str = str(self.prob1_dir) if self.prob1_dir else "Not selected"
        prob2_path_str = str(self.prob2_dir) if self.prob2_dir else "To be determined"
        self.lbl_paths.setText(
            f"CSV Version Directory: {csv_path_str}\nProblem1 Directory: {prob1_path_str}\nProblem2 (Output) Directory: {prob2_path_str}"
        )
        if self.csv_dir and self.prob1_dir:
            self.status.setText("Paths Ready ✅")
        else:
            self.status.setText("Please select all required directories.")

    def on_select_csv_dir(self):
        """Ask for the CSV directory and derive the Problem2 output directory from it."""
        dir_path = QFileDialog.getExistingDirectory(self, "Select CSV Version Directory")
        if dir_path:
            self.csv_dir = Path(dir_path)
            # Output goes next to a directory named "CSV版本", otherwise inside the chosen one
            self.prob2_dir = self.csv_dir.parent / "Problem2" if self.csv_dir.name == "CSV版本" else self.csv_dir / "Problem2"
            self.update_path_display()

    def on_select_prob1_dir(self):
        """Ask for the Problem1 directory."""
        dir_path = QFileDialog.getExistingDirectory(self, "Select Problem1 Directory")
        if dir_path:
            self.prob1_dir = Path(dir_path)
            self.update_path_display()

    def on_run_all(self):
        """Run Problem 2 for every task in ``self.tasks`` and show the results.

        Each task's metrics dict from ``solve_task`` is cached and appended to
        the table; the first failing task aborts the run with an error dialog.

        ``QApplication.processEvents()`` is pumped between tasks so the table
        repaints, which also dispatches button clicks mid-run. The run and
        directory-selection buttons are therefore disabled until the run
        ends: otherwise a second "Run" click would re-enter this method and
        clear ``self.table``/``self.results_cache`` under the outer loop, and
        a directory change would mix inputs from different directories.
        """
        if not (self.csv_dir and self.prob1_dir):
            QMessageBox.warning(self, "Prompt", "Please select CSV Version and Problem1 directories first.")
            return
        guarded_buttons = (self.btn_run, self.btn_select_csv, self.btn_select_prob1)
        for button in guarded_buttons:
            button.setEnabled(False)
        try:
            # Clear previous results
            self.table.setRowCount(0)
            self.results_cache = []
            for t in self.tasks:
                try:
                    info = solve_task(self.csv_dir, self.prob1_dir, self.prob2_dir, t)
                except Exception as e:
                    # UI boundary: report the failing task and stop the run.
                    QMessageBox.critical(self, "Error", f"{t} failed:\n{e}")
                    return
                self.results_cache.append(info)
                self.append_row(info)
                # Let the table repaint between (potentially slow) tasks.
                QApplication.processEvents()
            # Calculate total metrics
            total_mv = sum(x["extra_movement"] for x in self.results_cache)
            total_sp = sum(x["spill_count"] for x in self.results_cache)
            self.status.setText(f"Completed ✅  Total Extra Data Movement={total_mv}, Total SPILL Count={total_sp}. Results exported to {self.prob2_dir}.")
            QMessageBox.information(self, "Completed", "All test cases for Problem 2 have been processed. Detailed results are in the Problem2 folder.")
        finally:
            for button in guarded_buttons:
                button.setEnabled(True)

    def append_row(self, info):
        """Append one task's metrics dict as a new table row."""
        r = self.table.rowCount()
        self.table.insertRow(r)
        self.table.setItem(r,0,QTableWidgetItem(info["task"]))
        self.table.setItem(r,1,QTableWidgetItem(str(info["extra_movement"])))
        self.table.setItem(r,2,QTableWidgetItem(str(info["spill_count"])))
        self.table.setItem(r,3,QTableWidgetItem(str(info["new_nodes_added"])))
        self.table.setItem(r,4,QTableWidgetItem(str(info["schedule_len"])))

    def on_copy(self):
        """Copy the table (header plus rows) to the clipboard as tab-separated text."""
        rows = self.table.rowCount()
        cols = self.table.columnCount()
        lines = []
        # Header row
        header = [self.table.horizontalHeaderItem(c).text() for c in range(cols)]
        lines.append("\t".join(header))
        # Data rows; empty cells become empty strings
        for r in range(rows):
            row = []
            for c in range(cols):
                it = self.table.item(r,c)
                row.append(it.text() if it else "")
            lines.append("\t".join(row))
        txt = "\n".join(lines)
        cb = QApplication.clipboard()
        cb.setText(txt)
        QMessageBox.information(self, "Copied", "Statistics table copied to clipboard ✅")

    def on_show_chart1(self):
        """Show Chart 1: a bar chart of extra data movement per task."""
        if not self.results_cache:
            QMessageBox.information(self, "Prompt", "Please run Problem 2 first before viewing charts.")
            return
        def plot_bar(ax):
            tasks = [x["task"] for x in self.results_cache]
            mv    = [x["extra_movement"] for x in self.results_cache]
            ax.bar(tasks, mv)
            ax.set_title("Extra Data Movement by Task")
            ax.set_xlabel("Task")
            ax.set_ylabel("Extra Data Movement (Bytes)")
            ax.tick_params(axis='x', rotation=30)
        dlg = ChartDialog(self, "Chart 1: Extra Data Movement per Task", plot_bar)
        dlg.exec()

    def on_show_chart2(self):
        """Show Chart 2: a scatter plot of SPILL count against new schedule length."""
        if not self.results_cache:
            QMessageBox.information(self, "Prompt", "Please run Problem 2 first before viewing charts.")
            return
        def plot_scatter(ax):
            tasks = [x["task"] for x in self.results_cache]
            sp    = [x["spill_count"] for x in self.results_cache]
            slen  = [x["schedule_len"] for x in self.results_cache]
            ax.scatter(slen, sp)
            # Label each point with its task name
            for i, name in enumerate(tasks):
                ax.annotate(name, (slen[i], sp[i]), xytext=(5,5), textcoords="offset points", fontsize=8)
            ax.set_title("SPILL Count vs New Schedule Length")
            ax.set_xlabel("New Schedule Length")
            ax.set_ylabel("SPILL Count")
        dlg = ChartDialog(self, "Chart 2: SPILL Count vs New Schedule Length", plot_scatter)
        dlg.exec()

def main():
    """Application entry point: show the main window and run the Qt event loop.

    The event-loop method is resolved before the loop starts. Catching
    AttributeError around the running loop would treat an AttributeError
    escaping the loop as a missing method and start a second loop.
    ``exec`` is preferred; ``exec_`` is the fallback for bindings that only
    provide the underscored name.
    """
    app = QApplication(sys.argv)
    win = MainWin()
    win.show()
    run_event_loop = getattr(app, "exec", None) or app.exec_
    sys.exit(run_event_loop())

if __name__ == "__main__":
    # Launch the GUI only when executed as a script, not when imported.
    main()