In [None]:
import os
import json
import pandas as pd
import random
import numpy as np
import time
from collections import deque, defaultdict
from typing import List, Tuple  # New import to resolve undefined List

# ---------- Configuration ----------
DATA_DIR = r"./CSV版本"
OUT_DIR = os.path.join(DATA_DIR, "Attachment", "Problem1")
os.makedirs(OUT_DIR, exist_ok=True)

# ---------- Utility Functions ----------

def parse_bufs(x):
    """Return every run of consecutive digits found in ``x`` as an int.

    Values that ``pd.isna`` flags as missing produce an empty list. Any
    character that is not a digit acts purely as a separator.
    """
    if pd.isna(x):
        return []
    # Blank out everything that is not a digit, then let whitespace splitting
    # isolate the individual digit runs.
    spaced = ''.join(ch if ch.isdigit() else ' ' for ch in str(x))
    return [int(token) for token in spaced.split()]

def load_graph(nodes_csv: str, edges_csv: str):
    """Read the node and edge CSV files and build adjacency mappings.

    Args:
        nodes_csv: Path of the node table. Columns from the expected set that
            are absent are added and filled with ``None``.
        edges_csv: Path of the edge table. It must contain ``StartNodeId`` and
            ``EndNodeId`` columns (matched case-insensitively).

    Returns:
        ``(df_nodes, succ, pred)`` where ``succ`` and ``pred`` are
        ``defaultdict(list)`` objects keyed by node id. Edges whose endpoints
        are not both present in the node table are ignored.

    Raises:
        ValueError: If the edge table lacks one of the two endpoint columns.
    """
    df_nodes = pd.read_csv(nodes_csv)

    for column in ('Id', 'Op', 'BufId', 'Size', 'Type', 'Cycles', 'Pipe', 'Bufs'):
        if column not in df_nodes.columns:
            df_nodes[column] = None

    df_nodes['Id'] = pd.to_numeric(df_nodes['Id'], errors='coerce').astype(int)
    for column in ('BufId', 'Size', 'Cycles'):
        df_nodes[column] = pd.to_numeric(df_nodes[column], errors='coerce')
    df_nodes['BufsList'] = df_nodes['Bufs'].apply(parse_bufs)

    df_edges = pd.read_csv(edges_csv)
    # Map lower-cased header -> header as actually spelled in the file.
    by_lower = {name.lower(): name for name in df_edges.columns}
    if not ('startnodeid' in by_lower and 'endnodeid' in by_lower):
        raise ValueError("Edges CSV must contain columns: StartNodeId, EndNodeId")
    s_col = by_lower['startnodeid']
    e_col = by_lower['endnodeid']
    for endpoint in (s_col, e_col):
        df_edges[endpoint] = pd.to_numeric(df_edges[endpoint], errors='coerce').astype(int)

    known_ids = set(df_nodes['Id'].tolist())
    succ, pred = defaultdict(list), defaultdict(list)
    for u, v in zip(df_edges[s_col].tolist(), df_edges[e_col].tolist()):
        u, v = int(u), int(v)
        if u in known_ids and v in known_ids:
            succ[u].append(v)
            pred[v].append(u)

    return df_nodes, succ, pred

def memory_delta(row) -> int:
    """Return the signed memory change of one node row.

    ``ALLOC`` contributes ``+Size``, ``FREE`` contributes ``-Size`` and every
    other operation contributes 0. A missing ``Size`` counts as 0.
    """
    sign = {'ALLOC': 1, 'FREE': -1}.get(str(row['Op']))
    # 'Size' is only looked at for ALLOC/FREE rows.
    if sign is None or pd.isna(row['Size']):
        return 0
    return sign * int(row['Size'])

# ---------- Innovative Algorithm: Two-Stage Hybrid Optimization Strategy ----------

class InnovativeMemoryScheduler:
    """Two-stage scheduler that searches for a low peak-memory node order.

    Stage 1 (``cp_memory_aware_construction``) greedily builds an order that
    respects the dependency edges. Stage 2
    (``simulated_annealing_optimization``) perturbs that order with position
    swaps and keeps the best dependency-respecting order it encounters.

    Args:
        df_nodes: Node table; the ``Id``, ``Op`` and ``Size`` columns are read.
        succ: Mapping node id -> iterable of successor node ids.
        pred: Mapping node id -> iterable of predecessor node ids.
    """

    def __init__(self, df_nodes, succ, pred):
        self.df_nodes = df_nodes
        self.succ = succ
        self.pred = pred

        # A single pass over the table fills all three per-node lookups.
        self.node_ops = {}
        self.node_sizes = {}
        self.node_delta = {}
        for _, r in df_nodes.iterrows():
            nid = int(r['Id'])
            self.node_ops[nid] = str(r['Op'])
            self.node_sizes[nid] = int(r['Size']) if not pd.isna(r['Size']) else 0
            self.node_delta[nid] = memory_delta(r)

        self.num_nodes = len(self.node_delta)
        self.all_nodes = set(self.node_delta.keys())

        # Precompute dependency sets (duplicate edges collapse here).
        self.pred_set = {nid: set(preds) for nid, preds in pred.items()}
        self.succ_set = {nid: set(succs) for nid, succs in succ.items()}

        # Precompute node type sets
        self.alloc_nodes = {nid for nid, op in self.node_ops.items() if op == 'ALLOC'}
        self.free_nodes = {nid for nid, op in self.node_ops.items() if op == 'FREE'}
        self.other_nodes = self.all_nodes - self.alloc_nodes - self.free_nodes

        # The sum of all positive deltas never changes, so it is computed once
        # on first use instead of on every priority evaluation.
        self._max_possible_memory = None

        # Precompute criticality scores
        self.criticality_score = self._compute_criticality_scores()

    def _compute_criticality_scores(self):
        """Return, per node, the node count of the longest path starting there.

        Every node has unit weight, so a node without successors scores 1.
        The traversal uses an explicit stack, which keeps long dependency
        chains from exhausting the interpreter's recursion limit. An edge that
        points back to a node still being expanded (a cycle) is ignored.
        """
        score = {nid: 0 for nid in self.all_nodes}
        finished = set()

        for root in self.all_nodes:
            if root in finished:
                continue
            in_progress = {root}
            stack = [(root, iter(self.succ.get(root, ())))]
            while stack:
                nid, children = stack[-1]
                descended = False
                for child in children:
                    if child in finished or child in in_progress:
                        continue
                    in_progress.add(child)
                    stack.append((child, iter(self.succ.get(child, ()))))
                    descended = True
                    break
                if descended:
                    continue
                # All reachable successors are scored; finalise this node.
                stack.pop()
                in_progress.discard(nid)
                score[nid] = 1 + max(
                    (score.get(c, 0) for c in self.succ.get(nid, ()) if c in finished),
                    default=0,
                )
                finished.add(nid)

        return score

    def _calculate_priority(self, node, current_memory, alpha=0.5, beta=0.5):
        """Score a ready node; a larger value means "schedule sooner".

        The score is ``alpha * criticality + beta * memory_score`` where the
        criticality is the longest-path score scaled by the node count and the
        memory score favours FREE nodes and penalises large ALLOC nodes, both
        amplified by the current memory pressure.
        """
        norm_criticality = (self.criticality_score[node] / max(1, self.num_nodes)) * 100.0

        op = self.node_ops.get(node, '')
        delta_val = self.node_delta.get(node, 0)

        if op == 'FREE':
            # Higher score for freeing more memory (size scaled by 1024, capped at 100).
            memory_score = 100 + min(100, abs(delta_val) / 1024)
        elif op == 'ALLOC':
            # Lower allocation size -> higher score.
            memory_score = 50 - min(50, abs(delta_val) / 1024)
        else:
            memory_score = 50

        # Adjust based on current memory pressure.
        memory_pressure = current_memory / (max(1, self._get_max_possible_memory()) + 1)
        if op == 'FREE':
            memory_score *= (1 + memory_pressure)
        elif op == 'ALLOC':
            memory_score *= (1 - memory_pressure)

        return alpha * norm_criticality + beta * memory_score

    def _get_max_possible_memory(self):
        """Return the sum of all positive memory deltas (cached after first use)."""
        cached = getattr(self, '_max_possible_memory', None)
        if cached is None:
            cached = sum(d for d in self.node_delta.values() if d > 0)
            self._max_possible_memory = cached
        return cached

    def _compute_peak_memory(self, schedule):
        """Replay ``schedule`` and return the highest running memory total.

        The running total is clamped at 0 after every node.
        """
        current_mem, peak = 0, 0
        for node in schedule:
            current_mem = max(0, current_mem + self.node_delta.get(node, 0))
            if current_mem > peak:
                peak = current_mem
        return peak

    def _is_valid_schedule(self, schedule):
        """Return True if ``schedule`` is a dependency-respecting permutation.

        The schedule must contain every known node exactly once, and each
        node must appear after all of its predecessors.
        """
        # Comparing only the number of distinct entries would let duplicated
        # or unknown node ids slip through, so check length and membership.
        if len(schedule) != self.num_nodes or set(schedule) != self.all_nodes:
            return False

        position = {}
        for idx, node in enumerate(schedule):
            for pred_node in self.pred_set.get(node, ()):
                # A predecessor not yet placed comes later (or never).
                if pred_node not in position:
                    return False
            position[node] = idx
        return True

    # ---------- Stage 1: Critical-Path & Memory-Aware Greedy Construction ----------
    def cp_memory_aware_construction(self) -> List[int]:
        """Build an order by repeatedly taking the best-scoring ready node.

        Returns:
            The constructed order. If the graph contains a cycle the nodes on
            it never become ready, so the returned list is shorter than the
            node count and ``_is_valid_schedule`` rejects it.
        """
        in_degree = {nid: len(self.pred_set.get(nid, ())) for nid in self.all_nodes}
        ready = deque(nid for nid, deg in in_degree.items() if deg == 0)

        schedule = []
        current_memory = 0

        while ready:
            # max() keeps the first of equally scored nodes, i.e. queue order
            # breaks ties.
            best_node = max(
                ready,
                key=lambda nid: self._calculate_priority(nid, current_memory),
            )
            ready.remove(best_node)

            schedule.append(best_node)
            current_memory = max(0, current_memory + self.node_delta.get(best_node, 0))

            for neighbor in self.succ_set.get(best_node, ()):
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    ready.append(neighbor)

        return schedule

    # ---------- Stage 2: Simulated Annealing with Heuristic Neighbor Generation ----------
    def _generate_heuristic_neighbor(self, schedule):
        """Return a copy of ``schedule`` with two positions swapped.

        A FREE position is swapped with an ALLOC position when both kinds
        exist, otherwise with a position holding any other operation. When
        neither pairing is possible two random positions are swapped. The
        result is not checked against the dependencies.
        """
        candidate = schedule.copy()
        n = len(candidate)
        if n <= 1:
            return candidate

        # Drawing each side independently is the same distribution as picking
        # uniformly from every (FREE, partner) pair, without building the
        # quadratic list of pairs.
        free_pos = [i for i, nid in enumerate(candidate) if nid in self.free_nodes]
        partner_pos = []
        if free_pos:
            partner_pos = [j for j, nid in enumerate(candidate) if nid in self.alloc_nodes]
            if not partner_pos:
                partner_pos = [j for j, nid in enumerate(candidate) if nid in self.other_nodes]

        if partner_pos:
            i, j = random.choice(free_pos), random.choice(partner_pos)
        else:
            i, j = random.sample(range(n), 2)

        candidate[i], candidate[j] = candidate[j], candidate[i]
        return candidate

    def simulated_annealing_optimization(self, initial_schedule, initial_temp=1000, cooling_rate=0.95, iterations=800):
        """Refine ``initial_schedule`` with simulated annealing on peak memory.

        Args:
            initial_schedule: Starting order.
            initial_temp: Starting temperature.
            cooling_rate: Factor applied to the temperature after each
                evaluated neighbour.
            iterations: Number of neighbours to generate.

        Returns:
            The order with the lowest peak memory seen (possibly the input).
        """
        current_schedule = initial_schedule.copy()
        current_peak = self._compute_peak_memory(current_schedule)
        best_schedule = current_schedule.copy()
        best_peak = current_peak

        temp = initial_temp

        for _ in range(iterations):
            neighbor = self._generate_heuristic_neighbor(current_schedule)
            if not self._is_valid_schedule(neighbor):
                # Rejected neighbours do not advance the cooling schedule.
                continue

            neighbor_peak = self._compute_peak_memory(neighbor)

            if neighbor_peak < current_peak:
                current_schedule, current_peak = neighbor, neighbor_peak
                if neighbor_peak < best_peak:
                    best_schedule, best_peak = neighbor, neighbor_peak
            elif temp > 0:
                # temp > 0 guards the division when a caller-supplied cooling
                # schedule reaches zero.
                acceptance_prob = np.exp(-(neighbor_peak - current_peak) / temp)
                if random.random() < acceptance_prob:
                    current_schedule, current_peak = neighbor, neighbor_peak

            temp *= cooling_rate

        return best_schedule

    # ---------- Main Optimization Pipeline ----------
    def innovative_optimization(self):
        """Run both stages and return ``(schedule, peak_memory)``.

        Graphs with fewer than 20 nodes skip the annealing stage.
        """
        initial_schedule = self.cp_memory_aware_construction()

        if self.num_nodes < 20:
            return initial_schedule, self._compute_peak_memory(initial_schedule)

        final_schedule = self.simulated_annealing_optimization(initial_schedule)
        return final_schedule, self._compute_peak_memory(final_schedule)

# ---------- Main Workflow ----------

def run_innovative_scheduler(task_name: str, nodes_fn: str, edges_fn: str):
    """Schedule one task with the two-stage optimizer and persist the result.

    Args:
        task_name: Label used in log lines and output file names.
        nodes_fn: Node CSV file name, relative to ``DATA_DIR``.
        edges_fn: Edge CSV file name, relative to ``DATA_DIR``.

    Returns:
        The metrics dict that was written to disk, or ``None`` when either
        input file does not exist.
    """
    nodes_path = os.path.join(DATA_DIR, nodes_fn)
    edges_path = os.path.join(DATA_DIR, edges_fn)

    if not os.path.exists(nodes_path) or not os.path.exists(edges_path):
        print(f"[Skip] {task_name}: Missing CSV files")
        return None

    df_nodes, succ, pred = load_graph(nodes_path, edges_path)
    scheduler = InnovativeMemoryScheduler(df_nodes, succ, pred)

    print(f"[{task_name}] Starting innovative optimization...")
    schedule, peak_memory = scheduler.innovative_optimization()

    # Re-run the greedy stage alone if the optimized order breaks a dependency.
    if not scheduler._is_valid_schedule(schedule):
        print(f"[{task_name}] Warning: Generated schedule is invalid; falling back to greedy result")
        schedule = scheduler.cp_memory_aware_construction()
        peak_memory = scheduler._compute_peak_memory(schedule)

    # One node id per line, in execution order.
    schedule_path = os.path.join(OUT_DIR, f"{task_name}_innovative_schedule.txt")
    with open(schedule_path, "w", encoding="utf-8") as f:
        f.writelines(f"{nid}\n" for nid in schedule)

    node_count = len(df_nodes)
    metrics = {
        "task": task_name,
        "num_nodes": node_count,
        "peak_memory": peak_memory,
        "algorithm": "innovative_two_stage",
    }

    metrics_path = os.path.join(OUT_DIR, f"{task_name}_innovative_metrics.json")
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    print(f"[{task_name}] Done! Nodes={node_count} PeakMemory={peak_memory}")
    return metrics

def discover_tasks(data_dir: str) -> List[Tuple[str, str, str]]:
    """List the tasks in ``data_dir`` that have both a node and an edge CSV.

    A task is identified by the common prefix of ``<base>_Nodes.csv`` and
    ``<base>_Edges.csv``.

    Returns:
        ``(base, nodes_file_name, edges_file_name)`` tuples sorted by base.
    """
    node_suffix, edge_suffix = "_Nodes.csv", "_Edges.csv"
    with_nodes, with_edges = set(), set()
    for fname in os.listdir(data_dir):
        if fname.endswith(node_suffix):
            with_nodes.add(fname[:-len(node_suffix)])
        if fname.endswith(edge_suffix):
            with_edges.add(fname[:-len(edge_suffix)])

    tasks = []
    for base in sorted(with_nodes & with_edges):
        tasks.append((base, f"{base}{node_suffix}", f"{base}{edge_suffix}"))
    return tasks

# Script entry point: run the two-stage scheduler on every task discovered in
# DATA_DIR and collect one result dict per task.
if __name__ == "__main__":
    candidates = discover_tasks(DATA_DIR)
    results = []
    
    # Print the working directory so the relative DATA_DIR can be checked.
    print(os.path.abspath(os.path.curdir))
    print("candidates:", candidates)
    for task_name, nodes_file, edges_file in candidates:
        try:
            starttime = time.time()
            result = run_innovative_scheduler(task_name, nodes_file, edges_file)
            endtime = time.time()
            # run_innovative_scheduler returns None for a skipped task.
            if result:
                results.append(result)
            print("use time:", endtime - starttime)
            print("results:", results)
        except Exception as e:
            # A failing task is recorded as an error entry so the remaining
            # tasks still run.
            print(f"[{task_name}] Error: {str(e)}")
            results.append({"task": task_name, "error": str(e)})
    print("results:", results)

e:\mywork\liurunhao\NPU\问题一\问题一
candidates: [('Conv_Case0', 'Conv_Case0_Nodes.csv', 'Conv_Case0_Edges.csv'), ('Conv_Case1', 'Conv_Case1_Nodes.csv', 'Conv_Case1_Edges.csv'), ('FlashAttention_Case0', 'FlashAttention_Case0_Nodes.csv', 'FlashAttention_Case0_Edges.csv'), ('FlashAttention_Case1', 'FlashAttention_Case1_Nodes.csv', 'FlashAttention_Case1_Edges.csv'), ('Matmul_Case0', 'Matmul_Case0_Nodes.csv', 'Matmul_Case0_Edges.csv'), ('Matmul_Case1', 'Matmul_Case1_Nodes.csv', 'Matmul_Case1_Edges.csv')]


KeyboardInterrupt: 

In [None]:
import os, heapq, json, pandas as pd, collections
from typing import Dict, List, Tuple

# ---------- Configuration ----------
DATA_DIR = r"./CSV版本"
OUT_DIR = os.path.join(DATA_DIR, "Attachment", "Problem1")
os.makedirs(OUT_DIR, exist_ok=True)

# ---------- Utility Functions ----------

def parse_bufs(x):
    """Collect the integers embedded in a 'Bufs' cell.

    Accepts forms such as '[1,2]' or '1,2'; a missing value yields an empty
    list. Every non-digit character terminates the number being read.
    """
    if pd.isna(x):
        return []

    values = []
    pending = []  # digit characters of the number currently being read
    # The trailing space acts as a sentinel that flushes the final number.
    for ch in str(x) + ' ':
        if ch.isdigit():
            pending.append(ch)
        elif pending:
            values.append(int(''.join(pending)))
            pending = []
    return values

def load_graph(nodes_csv: str, edges_csv: str):
    """Read the node/edge CSV files and build successor and predecessor lists.

    Args:
        nodes_csv: Path of the node table. Expected columns that are absent
            are created and filled with ``None``.
        edges_csv: Path of the edge table; it must provide ``StartNodeId`` and
            ``EndNodeId`` columns (header case is ignored).

    Returns:
        ``(df_nodes, succ, pred)``. ``succ`` and ``pred`` are plain dicts with
        an entry (possibly an empty list) for every node id in the table.
        Edges touching an id that is not in the node table are dropped.

    Raises:
        ValueError: If an endpoint column is missing from the edge table.
    """
    df_nodes = pd.read_csv(nodes_csv)

    # Guarantee the columns the rest of the pipeline reads.
    expected = ('Id', 'Op', 'BufId', 'Size', 'Type', 'Cycles', 'Pipe', 'Bufs')
    for name in expected:
        if name not in df_nodes.columns:
            df_nodes[name] = None

    # Normalise dtypes.
    df_nodes['Id'] = pd.to_numeric(df_nodes['Id'], errors='coerce').astype(int)
    for name in ('BufId', 'Size', 'Cycles'):
        df_nodes[name] = pd.to_numeric(df_nodes[name], errors='coerce')
    df_nodes['BufsList'] = df_nodes['Bufs'].apply(parse_bufs)

    df_edges = pd.read_csv(edges_csv)
    # Resolve the endpoint columns regardless of header capitalisation.
    header_of = {name.lower(): name for name in df_edges.columns}
    if not ('startnodeid' in header_of and 'endnodeid' in header_of):
        raise ValueError("Edges CSV must contain columns: StartNodeId, EndNodeId")
    s_col = header_of['startnodeid']
    e_col = header_of['endnodeid']
    df_edges[s_col] = pd.to_numeric(df_edges[s_col], errors='coerce').astype(int)
    df_edges[e_col] = pd.to_numeric(df_edges[e_col], errors='coerce').astype(int)

    node_ids = set(df_nodes['Id'].tolist())
    succ = {int(nid): [] for nid in node_ids}
    pred = {int(nid): [] for nid in node_ids}
    for start, end in df_edges[[s_col, e_col]].itertuples(index=False, name=None):
        u, v = int(start), int(end)
        if u in node_ids and v in node_ids:
            succ[u].append(v)
            pred[v].append(u)
    return df_nodes, succ, pred

def memory_delta(row) -> int:
    """Return the memory change of a node: +Size (ALLOC), -Size (FREE), else 0.

    A missing ``Size`` on an ALLOC/FREE row is treated as 0.
    """
    op = str(row['Op'])
    if op not in ('ALLOC', 'FREE'):
        return 0

    size = row['Size']
    if pd.isna(size):
        return 0

    magnitude = int(size)
    return magnitude if op == 'ALLOC' else -magnitude

# ---------- Problem 1: Memory-First List Scheduling (MF-LS) ----------

def mf_ls_schedule(df_nodes: pd.DataFrame,
                   succ: Dict[int, List[int]],
                   pred: Dict[int, List[int]]) -> Tuple[List[int], int, List[Tuple[int,int]]]:
    """Memory-first list scheduling over the dependency graph.

    Among the ready nodes, FREE nodes go first (largest size first), then
    nodes with no memory effect, then the remaining nodes (smallest size
    first); the node id breaks ties.

    Args:
        df_nodes: Node table; ``Id``, ``Op`` and ``Size`` are read.
        succ: Mapping node id -> successor ids.
        pred: Mapping node id -> predecessor ids.

    Returns:
        ``(schedule, peak, mem_series)`` where ``mem_series`` holds one
        ``(node_id, memory_after_node)`` pair per scheduled node.

    Raises:
        RuntimeError: If the number of scheduled nodes differs from
            ``len(df_nodes)``.
    """
    delta, op_map = {}, {}
    for _, r in df_nodes.iterrows():
        nid = int(r['Id'])
        delta[nid] = memory_delta(r)
        op_map[nid] = str(r['Op'])

    # set().union(...) accepts dict key views directly.
    nodes = set().union(delta.keys(), succ.keys(), pred.keys())

    # Every known node starts at in-degree 0; nodes with predecessors are
    # then overwritten with their real count.
    indeg = {nid: 0 for nid in nodes}
    for v, preds in pred.items():
        indeg[v] = len(preds)

    def pri(nid: int):
        dz = delta.get(nid, 0)
        if op_map.get(nid, '') == 'FREE':
            return (0, -abs(dz), nid)  # FREE first; larger size pops earlier
        if dz == 0:
            return (1, 0, nid)         # no memory effect
        return (2, abs(dz), nid)       # everything else; smaller size first

    # Keys are unique (they end with the node id), so pop order does not
    # depend on the order in which entries enter the heap.
    heap = [(pri(nid), nid) for nid, d in indeg.items() if d == 0]
    heapq.heapify(heap)

    schedule, mem_series = [], []
    cur_mem = peak = 0
    while heap:
        nid = heapq.heappop(heap)[1]
        schedule.append(nid)
        # Running memory never drops below zero.
        cur_mem = max(0, cur_mem + delta.get(nid, 0))
        peak = max(peak, cur_mem)
        mem_series.append((nid, cur_mem))
        for v in succ.get(nid, []):
            indeg[v] -= 1
            if indeg[v] == 0:
                heapq.heappush(heap, (pri(v), v))

    if len(schedule) != len(df_nodes):
        missing = set(df_nodes['Id']) - set(schedule)
        raise RuntimeError(f"Graph not fully scheduled (possible cycle or isolated nodes not added to ready set). Missing {len(missing)} nodes.")
    return schedule, peak, mem_series

# ---------- Optional: Makespan Calculation ----------
def compute_makespan(df_nodes: pd.DataFrame, succ: Dict[int, List[int]], schedule: List[int]) -> int:
    """Simulate ``schedule`` and return the finish time of the last node.

    Each node starts once all of its predecessors have finished and the
    resource named by its ``Pipe`` value (``'MGMT'`` when missing) is free.
    ALLOC and FREE nodes take zero cycles; a missing ``Cycles`` counts as 0.

    Args:
        df_nodes: Node table; ``Id``, ``Op``, ``Cycles`` and ``Pipe`` are read.
        succ: Mapping node id -> successor ids.
        schedule: Node ids in issue order.

    Returns:
        The largest finish time, or 0 for an empty schedule.
    """
    duration, resource = {}, {}
    for _, r in df_nodes.iterrows():
        nid = int(r['Id'])
        raw_cycles = r['Cycles']
        cyc = 0 if pd.isna(raw_cycles) else int(raw_cycles)
        if str(r['Op']) in ('ALLOC', 'FREE'):
            cyc = 0
        duration[nid] = max(0, cyc)
        pipe = r['Pipe']
        resource[nid] = 'MGMT' if pd.isna(pipe) else str(pipe)

    # Invert the successor mapping; nodes without predecessors have no entry.
    preds = collections.defaultdict(list)
    for u, vs in succ.items():
        for v in vs:
            preds[v].append(u)

    res_end = collections.defaultdict(int)  # resource -> time it becomes free
    end = {}
    for nid in schedule:
        pipe = resource[nid]
        ready_at = max([0] + [end.get(p, 0) for p in preds.get(nid, [])])
        begin = max(ready_at, res_end[pipe])
        finish = begin + duration[nid]
        end[nid] = finish
        res_end[pipe] = finish

    if not end:
        return 0
    return max(end.values())

# ---------- Main Workflow (Callable from Notebook) ----------

def run_task(task_name: str, nodes_fn: str, edges_fn: str):
    """Schedule one task with MF-LS and write its schedule and metrics files.

    Args:
        task_name: Label used in log lines and output file names.
        nodes_fn: Node CSV file name, relative to ``DATA_DIR``.
        edges_fn: Edge CSV file name, relative to ``DATA_DIR``.

    Returns:
        The metrics dict that was written, or ``None`` when an input file is
        missing.
    """
    nodes_path = os.path.join(DATA_DIR, nodes_fn)
    edges_path = os.path.join(DATA_DIR, edges_fn)
    if not os.path.exists(nodes_path) or not os.path.exists(edges_path):
        print(f"[Skip] {task_name}: Missing CSV files -> {nodes_path} or {edges_path}")
        return None

    df_nodes, succ, pred = load_graph(nodes_path, edges_path)
    schedule, peak, _mem_series = mf_ls_schedule(df_nodes, succ, pred)
    makespan = compute_makespan(df_nodes, succ, schedule)

    # Schedule file: one node id per line, in execution order.
    schedule_path = os.path.join(OUT_DIR, f"{task_name}_schedule.txt")
    with open(schedule_path, "w", encoding="utf-8") as f:
        f.writelines(f"{nid}\n" for nid in schedule)

    # Cast to built-in int so json can serialise values that may be numpy ints.
    node_count = int(len(df_nodes))
    peak_memory = int(peak)
    makespan_cycles = int(makespan)
    metrics = {
        "task": task_name,
        "num_nodes": node_count,
        "peak_memory": peak_memory,
        "makespan_cycles": makespan_cycles,
    }
    metrics_path = os.path.join(OUT_DIR, f"{task_name}_metrics.json")
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    print(f"[{task_name}] Nodes={node_count} PeakMemory={peak_memory} Makespan={makespan_cycles}")
    return metrics

# ---------- Run All Tasks ----------
# Each task is (name, node CSV, edge CSV); both file names derive from the
# task name. run_task skips a task (e.g. Conv_Case0) when one of its CSV files
# is missing.
candidates = [
    (case, f"{case}_Nodes.csv", f"{case}_Edges.csv")
    for case in (
        "Matmul_Case0",
        "FlashAttention_Case0",
        "FlashAttention_Case1",
        "Conv_Case1",
        "Conv_Case0",
    )
]

results = []
for t, n, e in candidates:
    res = run_task(t, n, e)
    if not res:
        # Skipped task: nothing to record.
        continue
    results.append(res)

# Last expression of the cell: the notebook displays the collected metrics.
results

[Matmul_Case0] 节点数=4160 峰值驻留=131328 Makespan=148954
[FlashAttention_Case0] 节点数=1716 峰值驻留=42248 Makespan=58370
[FlashAttention_Case1] 节点数=6952 峰值驻留=171664 Makespan=204923
[Conv_Case1] 节点数=36086 峰值驻留=476790 Makespan=908159
[Conv_Case0] 节点数=2580 峰值驻留=62778 Makespan=471605


[{'task': 'Matmul_Case0',
  'num_nodes': 4160,
  'peak_memory': 131328,
  'makespan_cycles': 148954},
 {'task': 'FlashAttention_Case0',
  'num_nodes': 1716,
  'peak_memory': 42248,
  'makespan_cycles': 58370},
 {'task': 'FlashAttention_Case1',
  'num_nodes': 6952,
  'peak_memory': 171664,
  'makespan_cycles': 204923},
 {'task': 'Conv_Case1',
  'num_nodes': 36086,
  'peak_memory': 476790,
  'makespan_cycles': 908159},
 {'task': 'Conv_Case0',
  'num_nodes': 2580,
  'peak_memory': 62778,
  'makespan_cycles': 471605}]

In [None]:
import ast
import heapq
import os
from collections import defaultdict, deque

import pandas as pd

# Directory path for input files
file_dir = r"./CSV版本"

# --------------------------
# Part 1: Data Preprocessing
# --------------------------
# Reads the Matmul_Case0 node/edge tables, derives per-node attributes,
# in-degrees and per-buffer lifecycles, and writes each of them to a CSV file
# next to the inputs.
print("Starting data preprocessing...")

# Load node and edge files. os.path.join is used so the paths resolve on every
# platform; a hard-coded '\\' separator only works on Windows.
nodes_df = pd.read_csv(os.path.join(file_dir, 'Matmul_Case0_Nodes.csv'))
edges_df = pd.read_csv(os.path.join(file_dir, 'Matmul_Case0_Edges.csv'))

node_attr_path = os.path.join(file_dir, 'Matmul_Case0_node_attributes.csv')
in_degree_path = os.path.join(file_dir, 'Matmul_Case0_in_degree.csv')
buf_lifecycle_path = os.path.join(file_dir, 'Matmul_Case0_buf_lifecycle.csv')

# 1. Build node attribute table (keep Cycles field)
node_attributes = {}
for _, row in nodes_df.iterrows():
    node_id = int(row['Id'])
    node_attributes[node_id] = {
        'id': node_id,
        'op': row['Op'],
        'buf_id': int(row['BufId']) if pd.notna(row['BufId']) else None,
        'size': int(row['Size']) if pd.notna(row['Size']) else None,
        'cache_type': row['Type'] if pd.notna(row['Type']) else None,
        'cycles': int(row['Cycles']) if pd.notna(row['Cycles']) else 0,
        'predecessors': [],
        'successors': []
    }

# 2. Process edge information.
# Every node gets an in-degree entry up front; otherwise a node that appears
# in no edge would be absent from the table and could never be seeded as a
# ready node later on.
in_degree = defaultdict(int)
for node_id in node_attributes:
    in_degree[node_id] = 0

skipped_edges = 0
for src, dst in zip(edges_df['StartNodeId'], edges_df['EndNodeId']):
    src, dst = int(src), int(dst)
    # An edge that references an unknown node is skipped instead of raising
    # KeyError.
    if src not in node_attributes or dst not in node_attributes:
        skipped_edges += 1
        continue
    node_attributes[dst]['predecessors'].append(src)
    node_attributes[src]['successors'].append(dst)
    in_degree[dst] += 1
if skipped_edges:
    print(f"Skipped {skipped_edges} edges that reference unknown nodes.")

# 3. Build buffer lifecycle mapping: buffer id -> its ALLOC node, its FREE
# node and the nodes found between them.
buf_lifecycle = defaultdict(dict)
for node_id, attr in node_attributes.items():
    if attr['buf_id'] is None:
        continue
    if attr['op'] == 'ALLOC':
        buf_lifecycle[attr['buf_id']]['alloc'] = node_id
        buf_lifecycle[attr['buf_id']]['uses'] = []
    elif attr['op'] == 'FREE':
        buf_lifecycle[attr['buf_id']]['free'] = node_id

# Populate usage nodes: breadth-first walk from the ALLOC node along successor
# edges, recording each dequeued node (other than the ALLOC node itself) until
# the FREE node is dequeued.
for buf_id, lifecycle in buf_lifecycle.items():
    if 'alloc' not in lifecycle or 'free' not in lifecycle:
        continue
    alloc_id = lifecycle['alloc']
    free_id = lifecycle['free']
    visited = set()
    queue = deque([alloc_id])  # deque gives O(1) pops from the left
    while queue:
        current = queue.popleft()
        if current == free_id:
            break
        if current in visited:
            continue
        visited.add(current)
        if current != alloc_id:
            lifecycle['uses'].append(current)
        for succ in node_attributes[current]['successors']:
            if succ not in visited:
                queue.append(succ)

# 4. Save preprocessing results
node_attr_df = pd.DataFrame(list(node_attributes.values()))
node_attr_df.to_csv(node_attr_path, index=False)

in_degree_df = pd.DataFrame(list(in_degree.items()), columns=['NodeId', 'InDegree'])
in_degree_df.to_csv(in_degree_path, index=False)

buf_lifecycle_list = []
incomplete_buffers = 0
for buf_id, data in buf_lifecycle.items():
    # A buffer with only an ALLOC or only a FREE node has no complete record;
    # it is left out rather than raising KeyError on the missing key.
    if 'alloc' not in data or 'free' not in data:
        incomplete_buffers += 1
        continue
    buf_lifecycle_list.append({
        'BufId': buf_id,
        'AllocNode': data['alloc'],
        'UseNodes': ','.join(str(x) for x in data['uses']) if data['uses'] else '',
        'FreeNode': data['free']
    })
if incomplete_buffers:
    print(f"Skipped {incomplete_buffers} buffers without a matching ALLOC/FREE pair.")
# Explicit columns keep the header in the file even when no buffer qualifies.
buf_lifecycle_df = pd.DataFrame(
    buf_lifecycle_list, columns=['BufId', 'AllocNode', 'UseNodes', 'FreeNode']
)
buf_lifecycle_df.to_csv(buf_lifecycle_path, index=False)

print('Preprocessing completed. Results saved to the following CSV files:')
print(node_attr_path)
print(in_degree_path)
print(buf_lifecycle_path)

# --------------------------
# Part 2: Final Fixed Scheduling Sequence Generation
# --------------------------
print("\nGenerating final fixed scheduling sequence...")

# 1. Load preprocessed data.
# Paths are built with os.path.join so they resolve on every platform.
node_attr_df = pd.read_csv(os.path.join(file_dir, 'Matmul_Case0_node_attributes.csv'))
node_attributes = {}
for _, row in node_attr_df.iterrows():
    node_id = int(row['id'])
    node_attributes[node_id] = {
        'id': node_id,
        'op': row['op'],
        'buf_id': int(row['buf_id']) if pd.notna(row['buf_id']) else None,
        'size': int(row['size']) if pd.notna(row['size']) else None,
        'cache_type': row['cache_type'] if pd.notna(row['cache_type']) else None,
        'cycles': int(row['cycles']) if pd.notna(row['cycles']) else 0,
        # The lists are stored as their Python repr; literal_eval parses them
        # without executing arbitrary code from the file, unlike eval.
        'predecessors': ast.literal_eval(row['predecessors']),
        'successors': ast.literal_eval(row['successors'])
    }

in_degree_df = pd.read_csv(os.path.join(file_dir, 'Matmul_Case0_in_degree.csv'))
in_degree = {int(row['NodeId']): int(row['InDegree']) for _, row in in_degree_df.iterrows()}

# UseNodes is read as text: with type inference, a column that only holds
# single ids and blanks is parsed as float, and int('5.0') would then fail.
buf_lifecycle_df = pd.read_csv(
    os.path.join(file_dir, 'Matmul_Case0_buf_lifecycle.csv'),
    dtype={'UseNodes': str},
)
buf_lifecycle = {}
for _, row in buf_lifecycle_df.iterrows():
    buf_id = int(row['BufId'])
    use_nodes_str = str(row['UseNodes']) if pd.notna(row['UseNodes']) else ''
    uses = []
    # NOTE(review): a cell equal to '0' is treated as "no uses", which also
    # drops a buffer whose only use is node 0 — confirm this is intended.
    if use_nodes_str and use_nodes_str != 'nan' and use_nodes_str != '0':
        uses = list(map(int, use_nodes_str.split(',')))
    buf_lifecycle[buf_id] = {
        'alloc': int(row['AllocNode']),
        'uses': uses,
        'free': int(row['FreeNode'])
    }

# 2. Initialize scheduling data structures (with node status tracking)
schedule = []
candidate_heap = []
current_cache = 0
max_cache = 0
buf_in_use = set()
current_step = 0

# Node status tracking (to avoid duplicate scheduling or re-adding)
node_status = {node_id: 'pending' for node_id in node_attributes}  # pending/processing/completed
retry_count = defaultdict(int)  # Retry count per node
MAX_RETRY = 3  # Maximum retries before marking a node as invalid

# L0 cache progress tracking
l0_progress = {
    'L0A': {'in_use': False, 'current_op': None, 'progress': 0},
    'L0B': {'in_use': False, 'current_op': None, 'progress': 0},
    'L0C': {'in_use': False, 'current_op': None, 'progress': 0}
}

# Future demand queue parameter: how many heap entries are inspected when
# looking ahead for upcoming allocations.
K = 15


# --------------------------
# Core Fix 1: Multi-dimensional Priority Calculation (unchanged)
# --------------------------
def calculate_priority(node, current_step):
    """Return the scheduling priority of ``node`` (larger means sooner).

    Args:
        node: Attribute dict with the keys ``op``, ``buf_id``, ``size``,
            ``cycles`` and ``cache_type``.
        current_step: Number of nodes scheduled so far.

    The score is the sum of up to three weighted terms:
      * FREE with a buffer id: ``(100 + size) * 0.4`` (size counts as 0 when
        it is missing or the buffer is not in ``buf_lifecycle``); otherwise,
        for the listed buffer operations on a known buffer, the fraction of
        the alloc-to-free span already elapsed, times ``100 * 0.3``;
      * ``cycles * 0.2`` for anything that is neither ALLOC nor FREE;
      * ``50 * 0.1`` when the cache type is one of the L0 caches.
    """
    op_type = node['op']
    buf_id = node['buf_id']
    priority = 0

    if op_type == 'FREE' and buf_id is not None:
        size_known = node['size'] is not None and buf_id in buf_lifecycle
        buf_size = node['size'] if size_known else 0
        priority += (100 + buf_size) * 0.4
    elif op_type in ('ALLOC', 'COPY_IN', 'COPY_OUT', 'MMAD', 'SOFTMAX') and buf_id in buf_lifecycle:
        lifecycle = buf_lifecycle[buf_id]
        if 'free' in lifecycle and 'alloc' in lifecycle:
            total_lifecycle = lifecycle['free'] - lifecycle['alloc']
            if total_lifecycle > 0:
                elapsed = current_step - lifecycle['alloc']
                lifecycle_ratio = min(elapsed / total_lifecycle, 1.0)
                priority += lifecycle_ratio * 100 * 0.3

    if op_type not in ('ALLOC', 'FREE'):
        priority += node['cycles'] * 0.2

    if node['cache_type'] in ('L0A', 'L0B', 'L0C'):
        priority += 50 * 0.1

    return priority


# --------------------------
# Core Fix 2: Conflict Prediction Function (unchanged)
# --------------------------
def predict_cache_conflict(future_alloc_list, cache_type, current_cache_used):
    """Tell whether upcoming allocations would overflow a cache.

    Args:
        future_alloc_list: Iterable of dicts with 'size' and 'cache_type'.
        cache_type: Cache name to check ('L1', 'UB', 'L0A', 'L0B' or 'L0C').
        current_cache_used: Amount of that cache already occupied.

    Returns:
        True when the occupied amount plus the sizes of the listed
        allocations for ``cache_type`` exceeds the cache capacity; False
        otherwise, including when ``cache_type`` has no known capacity.
    """
    capacity = {'L1': 4096, 'UB': 1024, 'L0A': 256, 'L0B': 256, 'L0C': 512}.get(cache_type)
    if capacity is None:
        return False

    upcoming = 0
    for request in future_alloc_list:
        if request['cache_type'] == cache_type:
            upcoming += request['size']
    return current_cache_used + upcoming > capacity


# Seed the candidate heap with every node that has no unmet dependency
# and is still 'pending'.
for root_id, root_degree in in_degree.items():
    if root_degree != 0 or node_status[root_id] != 'pending':
        continue
    root_prio = calculate_priority(node_attributes[root_id], current_step=0)
    # Priorities are negated because heapq is a min-heap.
    heapq.heappush(candidate_heap, (-root_prio, root_id))
    # Flag the node as queued so it is not pushed a second time.
    node_status[root_id] = 'processing'


# --------------------------
# Core Fix 3: Main Scheduling Loop (resolves infinite loop)
# --------------------------
# Track total number of nodes to detect completion
total_nodes = len(node_attributes)


def _size_or_zero(node):
    """Return the node's 'size', treating a missing (None) size as 0."""
    return node['size'] if node['size'] is not None else 0


def _requeue_for_retry(neg_prio, node_id):
    """Count one retry for the node and push it back with its old priority."""
    retry_count[node_id] += 1
    heapq.heappush(candidate_heap, (neg_prio, node_id))


def _release_successors(done_node, next_step):
    """Decrement the in-degree of a scheduled node's successors.

    Successors that reach in-degree 0 while still 'pending' are pushed onto
    the candidate heap (priority evaluated at ``next_step``) and marked
    'processing'. Missing or already completed successors are skipped.
    """
    for succ_id in done_node['successors']:
        if succ_id not in node_attributes or node_status[succ_id] == 'completed':
            continue  # Skip missing or completed successors
        in_degree[succ_id] -= 1
        if in_degree[succ_id] == 0 and node_status[succ_id] == 'pending':
            succ_prio = calculate_priority(node_attributes[succ_id], next_step)
            heapq.heappush(candidate_heap, (-succ_prio, succ_id))
            node_status[succ_id] = 'processing'  # Mark as processing


# Main scheduling loop: each iteration pops the best candidate, may defer it
# (L0 busy / predicted cache overflow) or schedules it and releases its
# successors.
while candidate_heap:
    # Termination condition 1: All nodes scheduled
    if len(schedule) >= total_nodes:
        print(f"Early termination: All {total_nodes} nodes have been scheduled.")
        break

    current_step = len(schedule)

    # Step 1: Collect the ALLOC requests among the K best heap entries.
    # nsmallest leaves the heap untouched and yields the same entries, in the
    # same order, as popping K times from a copy.
    future_alloc_queue = []
    for _, node_id in heapq.nsmallest(K, candidate_heap):
        if node_status[node_id] == 'completed':  # Only consider unscheduled nodes
            continue
        node = node_attributes[node_id]
        if node['op'] == 'ALLOC' and node['cache_type'] is not None and node['buf_id'] is not None:
            future_alloc_queue.append({
                'buf_id': node['buf_id'],
                'size': _size_or_zero(node),
                'cache_type': node['cache_type']
            })

    # Step 2: Pop highest-priority valid node (skip completed or over-retried nodes)
    current_node = None
    while candidate_heap:
        neg_prio, current_node_id = heapq.heappop(candidate_heap)
        # Skip if already completed
        if node_status[current_node_id] == 'completed':
            continue
        # Skip if retry limit exceeded
        if retry_count[current_node_id] > MAX_RETRY:
            # NOTE(review): the node is marked completed without being added
            # to the schedule, so its successors are never released - confirm
            # this drop is intended.
            print(f"Node {current_node_id} exceeded max retry count ({MAX_RETRY}), marked as invalid.")
            node_status[current_node_id] = 'completed'
            continue
        # Valid node found
        current_node = node_attributes[current_node_id]
        break

    # Termination condition 2: No valid nodes left (heap exhausted)
    if current_node is None:
        print(f"Terminated: No valid nodes available for scheduling ({len(schedule)}/{total_nodes} scheduled).")
        break

    # Step 3: L0 cache fine-grained scheduling check (with retry limit)
    if current_node['op'] == 'ALLOC' and current_node['cache_type'] in ['L0A', 'L0B', 'L0C']:
        l0_type = current_node['cache_type']
        l0_info = l0_progress[l0_type]
        buf_id = current_node['buf_id']
        if buf_id is None:
            _requeue_for_retry(neg_prio, current_node_id)
            continue

        if l0_info['in_use']:
            free_node_id = None
            if l0_info['progress'] >= 90:
                # NOTE(review): the lookup matches the buffer id of the ALLOC
                # being scheduled, not the buffer of the op occupying the L0
                # cache - confirm this is the intended buffer.
                free_node_id = next((n for n in node_attributes if
                                     node_attributes[n]['op'] == 'FREE' and
                                     node_attributes[n]['buf_id'] == buf_id and
                                     node_status[n] == 'pending'), None)  # Only look for pending FREE nodes
            # Compare against None explicitly: node id 0 is a valid id.
            if free_node_id is None or in_degree[free_node_id] != 0:
                _requeue_for_retry(neg_prio, current_node_id)
                continue

            # Execute FREE node
            free_node = node_attributes[free_node_id]
            if buf_id in buf_in_use:
                current_cache -= _size_or_zero(free_node)
                buf_in_use.remove(buf_id)
            schedule.append(free_node_id)
            node_status[free_node_id] = 'completed'  # Mark as completed
            # Reset L0 state
            l0_info['in_use'] = False
            l0_info['current_op'] = None
            l0_info['progress'] = 0
            # The FREE node was scheduled outside the normal path, so its
            # successors must be released here or they would never run.
            _release_successors(free_node, len(schedule))
        elif buf_id in buf_lifecycle:
            use_node_id = next((n for n in buf_lifecycle[buf_id].get('uses', [])
                                if node_attributes[n]['op'] == 'MMAD' and
                                node_status[n] == 'pending'), None)  # Only pending MMAD nodes
            if use_node_id is not None:  # node id 0 is a valid id
                l0_info['in_use'] = True
                l0_info['current_op'] = use_node_id

    # Step 4: Buffer conflict prediction (with retry limit)
    if current_node['op'] == 'ALLOC' and current_node['cache_type'] is not None and current_node['buf_id'] is not None:
        cache_type = current_node['cache_type']

        # Compute current used capacity for this cache type
        current_cache_used = 0
        for bid in buf_in_use:
            if bid in buf_lifecycle and buf_lifecycle[bid].get('alloc') in node_attributes:
                alloc_node = node_attributes[buf_lifecycle[bid]['alloc']]
                if alloc_node['cache_type'] == cache_type and alloc_node['size'] is not None:
                    current_cache_used += alloc_node['size']

        if predict_cache_conflict(future_alloc_queue, cache_type, current_cache_used):
            # Find pending FREE nodes that can release space
            free_node_candidates = [
                n for n in node_attributes
                if (node_attributes[n]['op'] == 'FREE' and
                    node_attributes[n]['cache_type'] == cache_type and
                    node_attributes[n]['buf_id'] is not None and
                    in_degree[n] == 0 and
                    node_status[n] == 'pending')
            ]

            if free_node_candidates:
                # Prefer the FREE node that releases the most space.
                free_node_id = max(free_node_candidates,
                                   key=lambda n: _size_or_zero(node_attributes[n]))
                free_node = node_attributes[free_node_id]
                free_buf_id = free_node['buf_id']

                # Release buffer
                if free_buf_id in buf_in_use:
                    current_cache -= _size_or_zero(free_node)
                    buf_in_use.remove(free_buf_id)
                    if free_node['cache_type'] in ['L0A', 'L0B', 'L0C']:
                        l0_progress[free_node['cache_type']]['in_use'] = False
                        l0_progress[free_node['cache_type']]['current_op'] = None
                        l0_progress[free_node['cache_type']]['progress'] = 0

                # Schedule FREE node and release its successors, which the
                # normal path (Step 9) will not do for it.
                schedule.append(free_node_id)
                node_status[free_node_id] = 'completed'
                _release_successors(free_node, len(schedule))
                # Re-add current ALLOC node (increment retry)
                _requeue_for_retry(neg_prio, current_node_id)
                continue

    # Step 5: Execute current node (mark as completed immediately)
    schedule.append(current_node_id)
    node_status[current_node_id] = 'completed'  # Critical: prevent reprocessing
    retry_count[current_node_id] = 0  # Reset retry count

    # Step 6: Update L0 operation progress
    for l0_type in l0_progress:
        l0_info = l0_progress[l0_type]
        if l0_info['in_use'] and l0_info['current_op'] == current_node_id:
            # A missing cycle count is treated like a single cycle.
            cycles = current_node['cycles']
            step_progress = 100 / max(cycles if cycles is not None else 1, 1)
            l0_info['progress'] = min(l0_info['progress'] + step_progress, 100)
            buf_id = current_node['buf_id']
            if buf_id is not None and buf_id in buf_lifecycle and 'free' in buf_lifecycle[buf_id]:
                free_node_id = buf_lifecycle[buf_id]['free']
                if free_node_id in node_attributes and in_degree[free_node_id] == 0 and node_status[free_node_id] == 'pending':
                    free_node = node_attributes[free_node_id]
                    free_prio = calculate_priority(free_node, current_step + 1)
                    heapq.heappush(candidate_heap, (-free_prio, free_node_id))
                    node_status[free_node_id] = 'processing'

    # Step 7: Update cache state
    buf_id = current_node['buf_id']
    if current_node['op'] == 'ALLOC' and buf_id is not None:
        if buf_id not in buf_in_use:
            current_cache += _size_or_zero(current_node)
            buf_in_use.add(buf_id)

    elif current_node['op'] == 'FREE' and buf_id is not None:
        if buf_id in buf_in_use:
            current_cache -= _size_or_zero(current_node)
            buf_in_use.remove(buf_id)

    # Step 8: Update peak memory usage
    if current_cache > max_cache:
        max_cache = current_cache

    # Step 9: Update successor in-degrees (only add pending nodes)
    _release_successors(current_node, current_step + 1)


# 4. Save results and output summary (with completeness validation)
# Build the output path with os.path.join so the separator matches the host
# OS instead of hard-coding a Windows backslash.
schedule_path = os.path.join(file_dir, 'Matmul_Case0_schedule_final.csv')
schedule_df = pd.DataFrame({'NodeId': schedule})
schedule_df.to_csv(schedule_path, index=False)

# Output scheduling completeness info: any node absent from the schedule is
# reported as unscheduled.
completed_nodes = set(schedule)
missing_nodes = [node_id for node_id in node_attributes if node_id not in completed_nodes]
print("\nScheduling completed!")
print(f"Total nodes: {total_nodes} | Scheduled: {len(schedule)} | Unscheduled: {len(missing_nodes)}")
if missing_nodes:
    print(f"Unscheduled node IDs (first 10): {missing_nodes[:10]}...")
print(f"Peak memory footprint: {max_cache}")
print(f"Final schedule saved to: {schedule_path}")
# Slicing already returns the whole list when it has fewer than 20 entries.
print("\nSchedule preview (first 20 nodes):")
print(schedule[:20])
print("\nSchedule preview (last 20 nodes):")
print(schedule[-20:])

开始数据预处理...
预处理完成，结果保存到以下CSV文件：
./CSV版本\Matmul_Case0_node_attributes.csv
./CSV版本\Matmul_Case0_in_degree.csv
./CSV版本\Matmul_Case0_buf_lifecycle.csv

开始生成最终修复版调度序列...

最终调度完成！
总节点数：4160 | 已调度节点数：4160 | 未调度节点数：0
最大缓存驻留容量: 27904
最终调度序列已保存至: ./CSV版本\Matmul_Case0_schedule_final.csv

调度序列预览（前20个节点）:
[0, 5, 6, 14, 15, 23, 24, 32, 33, 41, 42, 50, 51, 59, 60, 68, 69, 74, 77, 78]

调度序列预览（后20个节点）:
[2933, 2935, 4035, 2934, 2936, 3146, 2937, 4156, 4157, 2938, 2940, 4038, 2939, 2941, 3149, 2942, 4158, 4159, 2943, 4143]
