In [3]:
from dataclasses import dataclass
from typing import Any, List

from collections import namedtuple
from snake_egg import EGraph, Rewrite, Var, vars
import re
import networkx as nx
import matplotlib.pyplot as plt
from networkx.algorithms import isomorphism
import json
import subprocess
import time
from collections import deque
from itertools import count

import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from math import ceil, log2

In [4]:
class Node:
    """Plain record for one XLS IR node or function parameter.

    Every field is optional.  `value` defaults to -1, which is what the line parser
    stores when a line carries no `value=` attribute.
    """

    def __init__(self, name=None, bitwidth=None, operation=None, operands=None, 
                 idNum=None, value=-1, pos=None, FuncIO=None, start=None, width=None, array_sizes=None, indices=None):
        # Identity and type information.
        self.name, self.bitwidth, self.operation = name, bitwidth, operation
        self.operands, self.idNum, self.value = operands, idNum, value
        # Source position and function-boundary role ("In", "Out", "Not").
        self.pos, self.FuncIO = pos, FuncIO
        # bit_slice parameters, array shape, and array-index operands.
        self.start, self.width = start, width
        self.array_sizes, self.indices = array_sizes, indices

    def __repr__(self):
        # (label, value) pairs in display order; idNum is shown under the label "id".
        shown = (
            ("name", self.name),
            ("bitwidth", self.bitwidth),
            ("operation", self.operation),
            ("operands", self.operands),
            ("id", self.idNum),
            ("value", self.value),
            ("pos", self.pos),
            ("FuncIO", self.FuncIO),
            ("start", self.start),
            ("width", self.width),
            ("array_sizes", self.array_sizes),
            ("indices", self.indices),
        )
        body = ", ".join(f"{label}={field}" for label, field in shown)
        return f"Node({body})"

def node_to_dict(node):
    """Convert a Node into the dict used both as graph node attributes and in the JSON output.

    Keys appear in a fixed order (dict insertion order is preserved in the JSON dump).
    """
    key_to_attribute = (
        ("OperationName", "name"),
        ("BitWidth", "bitwidth"),
        ("OperationType", "operation"),
        ("Operands", "operands"),
        ("idNum", "idNum"),
        ("Value", "value"),
        ("Pos", "pos"),
        ("FuncIO", "FuncIO"),
        ("Start", "start"),
        ("Width", "width"),
        ("ArraySize", "array_sizes"),
        ("Indices", "indices"),
    )
    return {key: getattr(node, attribute) for key, attribute in key_to_attribute}
    
def ParseIRFile(ir_text):
    """Split the IR text of one function into its non-blank lines.

    Whitespace around the whole text is removed first; each line keeps its own
    indentation.  Blank lines carry no IR node, so they are dropped instead of being
    handed to the line parser (an empty text therefore yields an empty list).
    """
    return [line for line in ir_text.strip().split("\n") if line.strip()]

def LineToDict(line):
    """Parse one line of XLS IR text into Node objects.

    Returns:
        * For a function header ("fn ..." or "top fn ..."): a list with one Node per
          parameter (FuncIO="In").  The list is empty when the parameter list is empty
          or the header does not match the signature pattern.
        * None for a closing "}" line or a blank line.
        * Otherwise a one-element list holding the node defined on the line.  A line
          beginning with "ret " gets FuncIO="Out"; every other node gets FuncIO="Not".
    """
    nodes = []
    stripped = line.strip()

    # Function header.  Anchor on the start of the line: a bare substring test for "fn"
    # would also fire on body lines whose node names or attributes contain "fn".
    if re.match(r"(?:top\s+)?fn\s", stripped):
        signature_match = re.search(r"fn\s+(\w+)\((.*?)\)\s*->\s*(.*?)\s*\{", stripped)
        if signature_match:
            for param in signature_match.group(2).split(','):
                if not param.strip():
                    # "fn f() -> ..." has an empty parameter list.
                    continue
                param_name, param_type = param.split(':')
                sizes = [int(size) for size in re.findall(r'\[(\d+)\]', param_type)]
                # The first bracketed size is the bit width; any further ones are array
                # dimensions.  Types without a bracketed size get bitwidth None.
                bitwidth = sizes.pop(0) if sizes else None
                nodes.append(Node(name=param_name.strip(), bitwidth=bitwidth,
                                  array_sizes=sizes or None, FuncIO="In"))
        return nodes

    if not stripped or stripped == "}":
        return None

    FuncIO = "Not"
    # Only a leading "ret " keyword marks the function output; a substring test would
    # also match names such as "secret" and then chop characters off the line.
    if stripped.startswith("ret "):
        FuncIO = "Out"
        line = stripped[len("ret "):]

    # Node name: the first identifier (optionally ".<digits>") followed by ':'.
    name = re.search(r"(\w+(\.\d+)?)\:", line).group(1)

    # Bit width and optional array dimensions, e.g. bits[32][4].
    bitwidth = None
    array_sizes = None
    array_match = re.search(r"bits\[(\d+)\]((?:\[\d+\])*)", line)
    if array_match:
        bitwidth = int(array_match.group(1))
        array_sizes = [int(size) for size in re.findall(r'\[(\d+)\]', array_match.group(2))] or None

    operation = re.search(r"= (\w+)", line).group(1)

    # Operands: the text after the first '=' and up to the next '=' holds the operation
    # name, its positional operands and the first keyword name; drop the non-operands.
    operands = re.findall(r"(\w+\.\d+|\w+)", line.split("=")[1])
    ignored_tokens = {operation, "value", "id", "pos", "start", "width", "indices", "to_apply"}
    operands = [op for op in operands if op not in ignored_tokens]

    # Array-index operands listed as indices=[...] (array_index nodes).
    indices = None
    indices_match = re.search(r"indices=\[(.*?)\]", line)
    if indices_match:
        indices = re.findall(r"(\w+\.\d+|\w+)", indices_match.group(1))

    idNum = int(re.search(r"id=(\d+)", line).group(1))

    # -1 when the line has no value= attribute.
    value_match = re.search(r"value=(\d+)", line)
    value = int(value_match.group(1)) if value_match else -1

    pos_match = re.search(r"pos=\[\((\d+,\d+,\d+)\)\]", line)
    pos = tuple(map(int, pos_match.group(1).split(","))) if pos_match else None

    # start/width only exist on bit_slice nodes.
    start = None
    width = None
    if operation == "bit_slice":
        start = int(re.search(r"start=(\d+)", line).group(1))
        width = int(re.search(r"width=(\d+)", line).group(1))

    nodes.append(Node(name, bitwidth, operation, operands, idNum, value, pos, FuncIO, start, width, array_sizes, indices))
    return nodes

def DictToGraph(G, NodeDict):
    """Add every Node in NodeDict (name -> Node) to the directed graph G, with operand edges.

    Function parameters (FuncIO == "In") are keyed by their name; all other nodes are
    keyed by idNum.  Edges run from each operand, and from each entry of `indices` when
    present, to the node that consumes it.  Parameters get no incoming edges.  G is
    mutated in place and returned.
    """
    def graph_key(node):
        # Parameters are created without an idNum, so they are keyed by name instead.
        return node.name if node.FuncIO == "In" else node.idNum

    node_entries = []
    edge_entries = []
    for node in NodeDict.values():
        node_entries.append((graph_key(node), node_to_dict(node)))
        if node.FuncIO == "In":
            continue
        sources = list(node.operands)
        if node.indices is not None:
            sources.extend(node.indices)
        edge_entries.extend((graph_key(NodeDict[source]), node.idNum) for source in sources)

    G.add_nodes_from(node_entries)
    G.add_edges_from(edge_entries)
    return G
    
    
def slice_ir_by_function(ir_content):
    """Split IR package text into {function_name: function_text}.

    A function starts at a line matching "fn name(...) -> type {" (optionally
    prefixed with "top") and ends at the next line that is exactly "}" once stripped.
    The header and closing brace are part of the stored text.  A new header closes
    any still-open function.  A function that never sees its "}" is not returned.
    """
    header_re = re.compile(r"(?:top\s+)?fn\s+(\w+)\((.*?)\)\s*->\s*(.*?)\s*\{")
    chunks = {}
    open_fn = None
    body = []

    for text_line in ir_content.split('\n'):
        header = header_re.match(text_line)
        if header is not None:
            if open_fn is not None:
                chunks[open_fn] = '\n'.join(body)
            body = []
            open_fn = header.group(1)
        if open_fn is not None:
            body.append(text_line)
        if text_line.strip() == '}' and open_fn is not None:
            chunks[open_fn] = '\n'.join(body)
            open_fn = None
            body = []

    return chunks



In [5]:
def count_registers_in_block(ir_content):
    """Count "reg ..." declarations that appear inside "block ... { ... }" sections.

    A section opens on a line starting with "block " and closes at the next line that
    is exactly "}" (both after stripping).  Lines outside any section are ignored.
    """
    total = 0
    inside = False
    for raw in ir_content.splitlines():
        text = raw.strip()
        if text.startswith("block "):
            inside = True
            continue
        if not inside:
            continue
        if text == "}":
            inside = False
        elif text.startswith("reg "):
            total += 1
    return total

def ReadScheduleIR(file_path):
    """Load a scheduled IR file into one merged DiGraph and count its block registers.

    Every function in the file is sliced out, parsed line by line and added to the
    same graph with DictToGraph.  Because DictToGraph keys parameters by name,
    same-named parameters of different functions end up as one graph node.

    Returns:
        (graph, register_count), where register_count is the number of "reg"
        declarations inside "block" sections of the file.
    """
    with open(file_path, 'r') as f:
        ir_content = f.read()
    ir_dict = slice_ir_by_function(ir_content)

    merged_graph = nx.DiGraph()
    for fn_content in ir_dict.values():
        # name -> Node for this function.
        NodeDict = {}
        for Line in ParseIRFile(fn_content):
            NodeList = LineToDict(Line)
            if NodeList is not None:
                for NodeObj in NodeList:
                    NodeDict[NodeObj.name] = NodeObj
        merged_graph = DictToGraph(merged_graph, NodeDict)

    register_count = count_registers_in_block(ir_content)
    return merged_graph, register_count


def read_SDC_pipeline_result(file_path):
    """Parse a textual SDC schedule report.

    Returns:
        (schedule, max_path_delay_ps).
        schedule is {'function': name, 'stages': [{'stage': int, 'timed_nodes': [
        {'node': name, 'node_delay_ps': int, 'path_delay_ps': int}, ...]}, ...]}.
        max_path_delay_ps is the largest path_delay_ps seen (0 if none).
    """
    with open(file_path, 'r') as handle:
        raw_text = handle.read()

    parsed = {}
    # Most recently opened stage / timed node, and the running maximum path delay.
    state = {'stage': None, 'node': None, 'max': 0}

    def on_function(text):
        parsed['function'] = text
        parsed['stages'] = []

    def on_stage(text):
        state['stage'] = {'stage': int(text), 'timed_nodes': []}
        parsed['stages'].append(state['stage'])

    def on_node(text):
        state['node'] = {'node': text}
        state['stage']['timed_nodes'].append(state['node'])

    def on_node_delay(text):
        state['node']['node_delay_ps'] = int(text)

    def on_path_delay(text):
        delay = int(text)
        state['node']['path_delay_ps'] = delay
        if delay > state['max']:
            state['max'] = delay

    # Patterns are tried in this order; the first match handles the line.
    dispatch = (
        (re.compile(r'^function: "(.*)"'), on_function),
        (re.compile(r'^\s*stage: (\d+)'), on_stage),
        (re.compile(r'^\s*node: "(.*)"'), on_node),
        (re.compile(r'^\s*node_delay_ps: (\d+)'), on_node_delay),
        (re.compile(r'^\s*path_delay_ps: (\d+)'), on_path_delay),
    )

    for text_line in raw_text.splitlines():
        for pattern, handler in dispatch:
            hit = pattern.match(text_line)
            if hit:
                handler(hit.group(1))
                break

    return parsed, state['max']

def register_SDC_result(G, schedule_dict):
    """Copy per-node timing from a parsed SDC schedule onto the nodes of G.

    Each timed node name is mapped to its graph key: the integer after the last '.'
    when that part is all digits (e.g. "add.3" -> 3), otherwise the whole name.  The
    node's attribute dict receives 'node_delay_ps', 'path_delay_ps' and 'stage'.

    Returns:
        (G, number of stage entries in schedule_dict['stages']).
    """
    stage_entries = schedule_dict['stages']
    for stage_entry in stage_entries:
        stage_index = stage_entry['stage']
        for timed in stage_entry['timed_nodes']:
            full_name = timed['node']
            suffix = full_name.rpartition('.')[2]
            graph_key = int(suffix) if suffix.isdigit() else full_name
            timing = {
                'node_delay_ps': timed['node_delay_ps'],
                'path_delay_ps': timed['path_delay_ps'],
                'stage': stage_index,
            }
            G.nodes[graph_key].update(timing)
    return G, len(stage_entries)


In [6]:
def run_unify_name(ir_input_path, ir_unify_out,
                   executable="/home/miao/xls/bazel-bin/xls/tools/UnifyName"):
    """Run the UnifyName tool on `ir_input_path`, writing its output IR to `ir_unify_out`.

    Errors are printed rather than raised, so the calling flow keeps going.

    Args:
        ir_input_path: IR file to read.
        ir_unify_out: path the tool writes its output to.
        executable: path of the UnifyName binary (defaults to the previously hard-coded path).

    Returns:
        True if the tool exited with status 0, False otherwise.
    """
    command = [executable, ir_input_path, ir_unify_out]
    try:
        # Pass an argument list without a shell so paths containing spaces or shell
        # metacharacters reach the tool verbatim.
        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        result.check_returncode()  # raises CalledProcessError on a non-zero exit status
    except subprocess.CalledProcessError as e:
        print("An error occurred while running the command.")
        print("Error message:", e.stderr)
        return False
    except OSError as e:
        # Without a shell, a missing/non-executable binary raises instead of exiting 127.
        print("An error occurred while running the command.")
        print("Error message:", e)
        return False
    print("Name Unification Run Done.")
    return True

def run_rewriter(command_executable, ir_input_path, json_output_path, ir_output_path):
    """Run the standalone rewriter: `command_executable <ir_in> <json> <ir_out>`.

    Errors are printed rather than raised, so the calling flow keeps going.

    Returns:
        True if the rewriter exited with status 0, False otherwise.
    """
    import shlex

    print("\nRunning Standalone Rewriter")
    # The executable string is split like shell words, so a value that carries extra
    # leading arguments still works as it did when the whole line went through a shell.
    # The file paths are appended as single arguments, so spaces in them are safe.
    command = shlex.split(command_executable) + [ir_input_path, json_output_path, ir_output_path]
    try:
        result = subprocess.run(command, capture_output=True, text=True)
        result.check_returncode()
    except subprocess.CalledProcessError as e:
        print("An error occurred while running the command.")
        print("Error message:", e.stderr)
        return False
    except OSError as e:
        # Without a shell, a missing/non-executable binary raises instead of exiting 127.
        print("An error occurred while running the command.")
        print("Error message:", e)
        return False
    print("Rewriter Run Done")
    return True

def run_sdc_scheduler(ir_output_path, output_verilog_path, schedule_result_path, output_schedule_ir_path, 
                      delay_model, TopFunctionName, clock_period_ps, clock_margin_precent, 
                      pipeline_stages, period_relaxation_percent,
                      codegen_executable="/home/miao/xls/bazel-bin/xls/tools/codegen_main"):
    """Run codegen_main with the pipeline generator on `ir_output_path`.

    Verilog, the textual schedule and the scheduled IR are written to the given paths.
    The pipeline is constrained by `clock_period_ps` (plus `clock_margin_precent` if
    given) when it is set; otherwise by `pipeline_stages` (plus
    `period_relaxation_percent` if given).  With neither set, no constraint flag is passed.
    Errors are printed rather than raised.

    Args:
        codegen_executable: path of the codegen_main binary (defaults to the previously
            hard-coded path).

    Returns:
        True if codegen exited with status 0, False otherwise.
    """
    command = [
        codegen_executable,
        ir_output_path,
        '--generator=pipeline',
        f'--delay_model={delay_model}',
        '--module_name=xls_test',  # Assuming module_name is static
        f'--top={TopFunctionName}',
        f'--output_verilog_path={output_verilog_path}',
        f'--output_schedule_path={schedule_result_path}',
        f'--output_schedule_ir_path={output_schedule_ir_path}',
    ]

    if clock_period_ps is not None:
        command.append(f'--clock_period_ps={clock_period_ps}')
        if clock_margin_precent is not None:
            command.append(f'--clock_margin_percent={clock_margin_precent}')
    elif pipeline_stages is not None:
        command.append(f'--pipeline_stages={pipeline_stages}')
        if period_relaxation_percent is not None:
            command.append(f'--period_relaxation_percent={period_relaxation_percent}')

    try:
        # Argument list without a shell: paths are passed verbatim, never shell-parsed.
        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        result.check_returncode()  # raises CalledProcessError on a non-zero exit status
    except subprocess.CalledProcessError as e:
        print("An error occurred while running the command.")
        print("Error message:", e.stderr)
        return False
    except OSError as e:
        # Without a shell, a missing/non-executable binary raises instead of exiting 127.
        print("An error occurred while running the command.")
        print("Error message:", e)
        return False
    print("Scheduling Run Done.")
    return True

def get_graph_with_scheduling(output_schedule_ir_path, schedule_result_path):
    """Merge the scheduled IR graph with the SDC schedule report.

    Returns:
        (graph annotated with per-node delays and stages, maximum path delay in ps,
         number of stage entries, number of block registers).
    """
    print("Reading Scheduling Result")
    scheduled_graph, reg_total = ReadScheduleIR(output_schedule_ir_path)
    report, worst_delay = read_SDC_pipeline_result(schedule_result_path)
    scheduled_graph, n_stages = register_SDC_result(scheduled_graph, report)
    print("Result Graph Generation Done")
    return scheduled_graph, worst_delay, n_stages, reg_total

In [7]:
# E-graph operator node types.  Binary operators carry fields (x, y); unary ones (x).
# The type names double as operator spellings in the expression strings that are
# eval()'d later; "not_" stands in for XLS `not`, since `not` is a Python keyword.
add = namedtuple("add", ["x", "y"])
sub = namedtuple("sub", ["x", "y"])
mul = namedtuple("mul", ["x", "y"])
div = namedtuple("div", ["x", "y"])
shll = namedtuple("shll", ["x", "y"])  # shift left
shrl = namedtuple("shrl", ["x", "y"])  # shift right
not_ = namedtuple("not_", ["x"])       # bitwise NOT
neg = namedtuple("neg", ["x"])         # arithmetic negation

# Evaluation function
def eval_ops(car, cdr):
    """Constant-folding callback handed to EGraph.

    Args:
        car: a numeric leaf (returned unchanged) or one of the operator namedtuple
            types (add, sub, mul, div, shll, shrl, not_, neg).
        cdr: the operator's child values.  Children that could not be folded are
            presumably passed as None (NOTE(review): confirm against snake_egg).

    Returns:
        The folded number, or None when the node cannot be folded.  That covers
        None/non-numeric children, an unknown operator, division by zero, a negative
        shift amount, and similar cases.
    """
    if isinstance(car, (int, float)):
        return car
    try:
        if len(cdr) == 0 or any(arg is None for arg in cdr):
            return None

        op = car
        a = cdr[0]
        if op == not_:
            return ~a
        if op == neg:
            return -a

        b = cdr[1]
        if op == add:
            return a + b
        if op == sub:
            return a - b
        if op == mul:
            return a * b
        if op == div:
            if b == 0:
                return None
            if isinstance(a, int) and isinstance(b, int):
                # Integer division truncating toward zero.  True division would fold
                # div(7, 2) to 3.5, which handle_node_generation cannot turn into a
                # literal (it only recognises digit strings).
                quotient = abs(a) // abs(b)
                return quotient if (a < 0) == (b < 0) else -quotient
            return a / b
        if op == shll:
            return a << b
        if op == shrl:
            return a >> b
    except Exception:
        # Folding is best-effort: any arithmetic/type error just means "not foldable".
        # Catching Exception (not a bare except) keeps KeyboardInterrupt working.
        return None

    return None

# Rewrite rules for the e-graph.
# Each entry of list_rules is [rule_name, lhs_pattern, rhs_pattern]; a, b and c are
# pattern variables, so each rule rewrites any term matching lhs into rhs.
# NOTE(review): these are identities over unbounded integers.  Fixed-width hardware
# semantics (wraparound, shift amounts >= bit width) are not modelled.  For example,
# merge_*_shift assumes b + c does not overflow the shift-amount width.  Confirm this
# is acceptable for the target bit widths.
a, b, c = vars("a b c")  # Variables for rewrite rules

list_rules: List[List[Any]] = [
    # DAC paper rules
    # (a*b)*c -> b*(a*c): associativity combined with commutativity.
    ["mul_associativity", mul(mul(a, b), c), mul(b, mul(a, c))],
    # (a+b)+c -> b+(a+c): associativity combined with commutativity.
    ["add_associativity", add(add(a, b), c), add(b, add(a, c))],
    ["mul_distributivity", mul(a, add(b, c)), add(mul(a, b), mul(a, c))],
    ["sum_same", add(a, a), mul(a, 2)],
    ["mul_sum_same", add(mul(a, b), b), mul(add(a, 1), b)],
    ["sub_to_neg", sub(a, b), add(a, neg(b))],
    ["mul_by_two", mul(a, 2), shll(a, 1)],
    ["merge_left_shift", shll(shll(a, b), c), shll(a, add(b, c))],
    ["merge_right_shift", shrl(shrl(a, b), c), shrl(a, add(b, c))],
    # Two's-complement negation: -a == ~a + 1.
    ["neg_to_not", neg(a), add(not_(a), 1)],
    
    # Basic arithmetic rules
    ["comm_add", add(a, b), add(b, a)],
    ["comm_mul", mul(a, b), mul(b, a)],
    ["zero_add", add(a, 0), a],
    ["zero_mul", mul(a, 0), 0],
    ["one_mul", mul(a, 1), a],
    ["sub_zero", sub(a, 0), a],
    ["neg", neg(a), sub(0, a)],

    # Bitwise rules
    ["shll_zero", shll(a, 0), a],
    ["shrl_zero", shrl(a, 0), a],
    ["not_not", not_(not_(a)), a],
    
    
]

# Convert list rules into rewrites (Rewrite(lhs, rhs, name)); order follows list_rules.
rules = [Rewrite(frm, to, name) for name, frm, to in list_rules]

# Function to check if two expressions are equivalent
def is_equal(expr_a, expr_b, iters=5):
    """Return whether `rules` prove expr_a and expr_b equivalent within `iters` iterations.

    A fresh e-graph (with eval_ops constant folding) is built for every call.
    """
    graph = EGraph(eval_ops)
    left, right = graph.add(expr_a), graph.add(expr_b)
    graph.run(rules, iters)
    return graph.equiv(left, right)

# def get_egraph_in_range(expr, min_iters=0, max_iters=1000):
#     expr_list = []
#     # Create an EGraph with the evaluation function
#     egraph = EGraph(eval_ops)

#     # Add the expression to the e-graph
#     expr_id = egraph.add(expr)
       
#     # Run the e-graph with the rewrite rules     
        
#     for cur_iters in range(min_iters, max_iters):
#         egraph.run(rules, cur_iters)
#         cur_expr = egraph.extract(expr_id)
#         if cur_expr in expr_list:
#             pass
#         else:
#             expr_id = egraph.add(cur_expr)
#             expr_list.append(cur_expr)
        
#     return expr_list
# Module-level e-graph kept so the name `egraph` still exists for any notebook cell that
# refers to it; get_egraph_in_range no longer shares it between calls.
egraph = EGraph(eval_ops)


def get_egraph_in_range(expr, max_iters=5, egraph=None):
    """Collect the distinct best forms of `expr` extracted after each rewrite iteration.

    Args:
        expr: expression built from the operator namedtuples.
        max_iters: number of single-iteration rule applications to run.
        egraph: optional EGraph to work in.  By default a fresh e-graph is created per
            call, so the result for one expression does not depend on which expressions
            earlier calls added (a shared e-graph would also grow without bound).

    Returns:
        List of distinct extracted expressions, in the order first seen.  The list is
        empty when max_iters <= 0.
    """
    if egraph is None:
        egraph = EGraph(eval_ops)

    expr_id = egraph.add(expr)
    expr_list = []
    for _ in range(max_iters):
        egraph.run(rules, 1)  # one iteration at a time, so every intermediate form is seen
        cur_expr = egraph.extract(expr_id)
        # The extracted term is already in expr_id's e-class, so re-adding it to the
        # e-graph would not change anything.
        if cur_expr not in expr_list:
            expr_list.append(cur_expr)

    return expr_list
# Demo: the rules should prove a + a == a * 2.  A dedicated name holds the symbolic
# leaf so the rule pattern variable `a` (from vars("a b c")) is not overwritten.
demo_leaf = "a"


def test_sum_same(leaf=demo_leaf):
    """Assert that is_equal proves add(leaf, leaf) == mul(leaf, 2)."""
    assert is_equal(add(leaf, leaf), mul(leaf, 2))


test_sum_same()
expr_text = add(demo_leaf, demo_leaf)
print(get_egraph_in_range(expr_text))

[mul(x='a', y=2)]


In [8]:
# gen_json_from_ir("/home/miao/xls/work_space/EGraphUnitTest/all_unit_test.ir", "/home/miao/xls/work_space/EGraphUnitTest/all_unit_test.json")

In [9]:
def remove_entries(operand_list, to_remove):
    """Return a new list of the items of operand_list not contained in to_remove, order kept."""
    kept = []
    for entry in operand_list:
        if entry in to_remove:
            continue
        kept.append(entry)
    return kept

def traverse_graph(G, start_node, processed_nodes,
                   allowed_operations=("add", "umul", "smul", "shll", "shrl", "not", "neg", "sub", "udiv", "sdiv")):
    """Grow a rewrite batch upward from `start_node` through its operand nodes.

    A node joins the batch when its 'OperationType' is in `allowed_operations`.  A
    predecessor (operand) is explored only when every one of its successors (users) is
    already in the batch and it was not claimed by an earlier batch.  As a result, no
    batch node other than `start_node` has users outside the batch.  `start_node` is
    popped first, so it is batch[0] whenever it is allowed.

    Args:
        G: directed graph with edges operand -> user; node attributes hold 'OperationType'.
        start_node: graph key of the candidate batch output.
        processed_nodes: keys already assigned to earlier batches; never explored again.
        allowed_operations: operation types that may be rewritten.  The default is the
            previously hard-coded list.  When 'literal' is allowed, a batch consisting
            of a single literal is discarded.

    Returns:
        Graph keys in visiting order (start node first), or [] if nothing qualifies.
    """
    allowed = set(allowed_operations)
    batch = []
    visited = set()
    stack = [start_node]

    while stack:
        current = stack.pop()
        if G.nodes[current]['OperationType'] not in allowed:
            # Disallowed nodes (parameters, literals by default, ...) end the walk here.
            continue
        batch.append(current)
        visited.add(current)

        for parent in G.predecessors(current):
            # Absorb a parent only if all of its users are already inside this batch.
            if (parent not in processed_nodes and parent not in visited
                    and all(child in visited for child in G.successors(parent))):
                stack.append(parent)

    # A lone literal is not worth rewriting.
    if len(batch) == 1 and G.nodes[batch[0]]['OperationType'] == 'literal':
        batch = []
    return batch

def get_expr_from_batch(G, batch):
    """Build the e-graph expression text for one rewrite batch.

    Each batch node becomes `op(arg,arg,...)`.  An operand produced by another batch
    node is inlined as that node's expression; any other operand becomes a quoted
    leaf name such as 'x'.  XLS op names are mapped to the e-graph operator names
    (umul/smul -> mul, udiv/sdiv -> div, not -> not_).

    Args:
        G: graph whose node attributes include 'OperationName', 'OperationType',
           'Operands' (operand node names) and 'Value'.
        batch: graph keys of the batch nodes.

    Returns:
        (expression_text, is_unsigned).  is_unsigned is False if the batch contains a
        neg, smul or sdiv node.  Returns (None, None) for an empty batch, or (after
        printing an error) when the dependencies cannot be resolved.
    """
    if not batch:
        return None, None

    # Map the operator itself rather than calling str.replace on the finished text;
    # replacing text would also rewrite operand names such as 'umul.9' or 'not.7'.
    op_alias = {"umul": "mul", "smul": "mul", "udiv": "div", "sdiv": "div", "not": "not_"}
    # Operands are node *names*, while `batch` holds graph keys, so dependencies must be
    # checked against the names of the batch members.
    batch_names = {G.nodes[node_key]['OperationName'] for node_key in batch}

    N = len(batch)
    max_loop_count = N * (N + 1) // 2  # bound on iterations (including requeues) before giving up
    batch_queue = deque(batch)
    visited_dict = {}  # node name -> expression text (or literal value text) already built
    output_expr = None
    loop_count = 0
    is_unsigned = True

    while batch_queue:
        cur_node = batch_queue.pop()
        attrs = G.nodes[cur_node]
        operands = attrs['Operands']
        # Ready once every operand produced inside the batch has been translated.
        if all(operand not in batch_names or operand in visited_dict for operand in operands):
            op_type = attrs['OperationType']
            if op_type == 'literal':
                visited_dict[attrs['OperationName']] = str(attrs['Value'])
            else:
                args = [visited_dict[operand] if operand in visited_dict else "'" + operand + "'"
                        for operand in operands]
                output_expr = op_alias.get(op_type, op_type) + "(" + ",".join(args) + ")"
                visited_dict[attrs['OperationName']] = output_expr
            if op_type in ("neg", "smul", "sdiv"):
                is_unsigned = False
        else:
            # Not ready yet: retry after the remaining nodes.
            batch_queue.appendleft(cur_node)

        if loop_count > max_loop_count:
            break
        loop_count += 1

    if batch_queue or output_expr is None:
        print("[ERROR] Unresolved dependency in batch to expr generation")
        return None, None
    # The last operation translated has no remaining dependants in the batch; for
    # traverse_graph batches this is the batch output.
    return output_expr, is_unsigned

def gen_json_from_expr_recursive(G, expr, operation_mapping, bit_width, nodes_involved_dict=None, counter=None):
    """Turn an expression string into generated-node entries, innermost call first.

    The expression text is repeatedly searched for an innermost call `op(args)` (one
    whose argument text contains no parentheses).  Every occurrence of that exact
    call is replaced by a placeholder name 'auto_gen<counter>', and
    handle_node_generation is asked to record the node.  The loop stops when no call
    remains.

    Returns:
        (nodes_involved_dict, remaining_text).  remaining_text is the name of the node
        holding the whole expression (or the original text if it contained no call).
    """
    if counter is None:
        counter = 0
    if nodes_involved_dict is None:
        nodes_involved_dict = {}

    innermost_call = re.compile(r'\b(\w+)\(([^()]+)\)')
    remaining = expr
    while True:
        found = innermost_call.search(remaining)
        if found is None:
            return nodes_involved_dict, remaining
        op_type, op_args = found.groups()
        placeholder = 'auto_gen' + str(counter)
        remaining = remaining.replace(found.group(0), placeholder)
        nodes_involved_dict, counter = handle_node_generation(
            G, op_type, op_args, placeholder, bit_width, nodes_involved_dict, operation_mapping, counter)
        counter += 1


# Node(name=None, bitwidth=None, operation=None, operands=None, 
#                  idNum=None, value=None, pos=None, FuncIO=None, start=None, width=None)

def handle_node_generation(G, operation_type, operands, operation_name, bit_width, nodes_involved_dict, operation_mapping, counter):
    """Record one generated operation node (and any literal operands it needs).

    Args:
        operation_type: e-graph operator name, looked up in operation_mapping.
        operands: comma-separated argument text from a namedtuple repr, e.g. "x='a', y=2".
        operation_name: name to give the generated operation node.
        bit_width: width of the generated nodes.  It is raised to a power of two when a
            literal operand needs more bits, and the raised width applies to later
            literals and to the operation node.
        counter: running counter; each literal operand increments it and is named
            'auto_gen<counter>'.

    Returns:
        (nodes_involved_dict, counter).  Every new entry is tagged ReplaceSelfWith='Gen'.
    """
    resolved_operands = []
    for raw_operand in operands.split(','):
        token = raw_operand.split("=")[-1]  # drop the "x=" / "y=" field label
        digits = token.strip("'").strip(" ")
        if not digits.isdigit():
            # Symbolic operand: an existing node name or an earlier placeholder.
            resolved_operands.append(token.strip("'"))
            continue

        # Numeric operand: materialise it as a new Literal node.
        counter += 1
        try:
            literal_value = int(digits)
        except ValueError:
            literal_value = float(digits)

        if literal_value >= 0:
            needed_bits = ceil(log2(abs(literal_value) + 1))
        else:
            needed_bits = ceil(log2(abs(literal_value)))
        if needed_bits > bit_width:
            bit_width = 2 ** ceil(log2(needed_bits))

        literal_name = 'auto_gen' + str(counter)
        literal_entry = node_to_dict(Node(literal_name, bit_width, 'Literal', value=literal_value))
        literal_entry["ReplaceSelfWith"] = 'Gen'
        nodes_involved_dict[literal_name] = literal_entry
        resolved_operands.append(literal_name)

    op_entry = node_to_dict(Node(operation_name, bit_width, operation_mapping[operation_type], resolved_operands))
    op_entry["ReplaceSelfWith"] = 'Gen'
    nodes_involved_dict[operation_name] = op_entry
    return nodes_involved_dict, counter


In [10]:
def gen_json_from_ir(ir_input_path, json_output_path, rewrite_rule = "NoRule"):
    """Generate e-graph rewrite instructions for an IR file and write them as JSON.

    Steps:
      1. Parse every function into Node objects and build one DiGraph per function.
      2. Walk each acyclic graph in reverse topological order and carve it into
         rewrite batches with traverse_graph.
      3. Turn each batch into an expression (get_expr_from_batch), explore it with the
         e-graph (get_egraph_in_range) and pick one candidate according to `rewrite_rule`.
      4. For every batch, record the generated nodes ('Gen'), the original batch nodes
         ('Kill') and the batch output (redirected to the new output node).  Dump
         everything to `json_output_path`.

    Args:
        ir_input_path: IR text file to read.
        json_output_path: destination of the JSON instructions.
        rewrite_rule: "NoRule" takes the last candidate.  "NaivePick" calls
            `naive_pick` (NOTE(review): not defined in this part of the notebook —
            confirm it exists before selecting it).  Any other value behaves like "NoRule".

    Returns:
        (G, TopFunctionName).  G is the graph of the last function processed (None if
        the file has no functions); its batch nodes now carry a 'ReplaceSelfWith'
        attribute.  TopFunctionName is the function declared with "top fn", or None.
    """
    with open(ir_input_path, 'r') as f:
        ir_content = f.read()
    ir_dict = slice_ir_by_function(ir_content)

    ## 1. Parse every function into a name -> Node dictionary.
    TopFunctionName = None
    FuncNodeDict = {}
    for fn_name, fn_content in ir_dict.items():
        NodeDict = {}
        for Line in ParseIRFile(fn_content):
            NodeList = LineToDict(Line)
            if NodeList is not None:
                for NodeObj in NodeList:
                    NodeDict[NodeObj.name] = NodeObj
            # Only a "top fn ..." header marks the top function; a bare substring test
            # for "top" would also fire on node names such as "stop.3".
            if re.match(r"\s*top\s+fn\b", Line):
                TopFunctionName = fn_name
        FuncNodeDict[fn_name] = NodeDict

    ## 2. Build the per-function graphs and order their nodes.
    graph_by_function = {}
    sorted_nodes_by_function = {}
    G = None
    for fn_name, fn_nodes in FuncNodeDict.items():
        G = DictToGraph(nx.DiGraph(), fn_nodes)
        graph_by_function[fn_name] = G
        if nx.is_directed_acyclic_graph(G):
            # Reverse topological order: users come before their operands, so each
            # batch is grown from its output node.
            sorted_nodes_by_function[fn_name] = list(nx.topological_sort(G))[::-1]
        else:
            sorted_nodes_by_function[fn_name] = None
            print(f"[Warning] Graph for function {fn_name} is not a DAG. Topological sorting cannot be performed.")

    ## Partition each graph into legal rewrite batches.
    rewrite_batch_by_function = {}
    processed_nodes = set()
    for fn_name, nodes in sorted_nodes_by_function.items():
        G = graph_by_function[fn_name]
        rewrite_batch_by_function[fn_name] = []
        if nodes is None:
            # Non-DAG function (reported above): there is no order to batch it in.
            continue
        temp_nodes = nodes.copy()
        cur_index = 0
        while cur_index < len(temp_nodes):
            cur_batch = traverse_graph(G, temp_nodes[cur_index], processed_nodes)
            if cur_batch:
                rewrite_batch_by_function[fn_name].append(cur_batch)
                batch_members = set(cur_batch)
                processed_nodes.update(batch_members)
                # Removing the batch shifts the next candidate into cur_index.
                temp_nodes = [node for node in temp_nodes if node not in batch_members]
            else:
                cur_index += 1

    ## 3. Turn every batch into an expression and collect e-graph rewrite candidates.
    # The candidate lists stay index-aligned with rewrite_batch_by_function.
    rewrite_candidate_by_function = {}
    for fn_name, batch_list in rewrite_batch_by_function.items():
        G = graph_by_function[fn_name]
        rewrite_candidate_by_function[fn_name] = {'is_unsigned': [], 'rewrite_candidates': []}
        for batch in batch_list:
            batch_expr, is_unsigned = get_expr_from_batch(G, batch)
            print("FunctionName: ", fn_name)
            print("Input Expr: ", batch_expr)
            if batch_expr is None:
                # Expression could not be built; an empty candidate list is skipped below.
                cur_candidates = []
            else:
                # NOTE: eval() runs text assembled from IR operand names; only use on trusted IR.
                cur_candidates = get_egraph_in_range(eval(batch_expr))
            rewrite_candidate_by_function[fn_name]['is_unsigned'].append(is_unsigned)
            rewrite_candidate_by_function[fn_name]['rewrite_candidates'].append(cur_candidates)
            print(cur_candidates)

    operation_mapping_unsigned = {"add" : "kAdd",
                                  "mul" : "kUMul",
                                  "div" : "kUDiv",
                                  "sub" : "kSub",
                                  "literal" : "Literal",
                                  "neg" : "kNeg",
                                  "shll" : "kShll",
                                  "shrl" : "kShrl",
                                  "not_" : "kNot"}

    operation_mapping_signed = {"add" : "kAdd",
                                "mul" : "kSMul",
                                "div" : "kSDiv",
                                "sub" : "kSub",
                                "literal" : "Literal",
                                "neg" : "kNeg",
                                "shll" : "kShll",
                                "shrl" : "kShrl",
                                "not_" : "kNot"}

    ## 4. Emit one JSON entry per rewritten batch.
    json_dict = {}
    counter = 0
    for fn_name, all_rewrite_candidates_dict in rewrite_candidate_by_function.items():
        G = graph_by_function[fn_name]
        for i, cur_rewrite_candidate in enumerate(all_rewrite_candidates_dict['rewrite_candidates']):
            if not cur_rewrite_candidate:
                # Nothing to rewrite with (expression not built, or no candidates found).
                continue
            batch = rewrite_batch_by_function[fn_name][i]
            json_dict[str(counter)] = {"FuncName": fn_name}

            # Widest operation in the batch; nodes without a parsed width (None) are skipped.
            bit_width = max((G.nodes[node]['BitWidth'] for node in batch
                             if G.nodes[node]['BitWidth'] is not None), default=0)
            if all_rewrite_candidates_dict['is_unsigned'][i]:
                cur_operation_mapping = operation_mapping_unsigned
            else:
                cur_operation_mapping = operation_mapping_signed

            # Select the rewrite to apply.
            if rewrite_rule == "NaivePick":
                rewrite_expr = naive_pick(G, cur_rewrite_candidate, bit_width, cur_operation_mapping)
            else:
                # "NoRule" and any unrecognised rule take the last distinct candidate found.
                rewrite_expr = cur_rewrite_candidate[-1]
            print("Rewrite Expr", rewrite_expr)

            # Generated nodes for the chosen expression.
            nodes_involved_dict, out_node_name = gen_json_from_expr_recursive(
                G, str(rewrite_expr), cur_operation_mapping, bit_width)

            # Every original batch node is killed (its graph attribute dict is tagged
            # and stored directly, so G sees the tag too)...
            for old_node_id in batch:
                old_attrs = G.nodes[old_node_id]
                nodes_involved_dict[old_attrs['OperationName']] = old_attrs
                old_attrs['ReplaceSelfWith'] = 'Kill'

            # ...except the batch output: batch[0], the traverse_graph start node, which
            # is redirected to the generated output node.
            out_attrs = G.nodes[batch[0]]
            nodes_involved_dict[out_attrs['OperationName']] = out_attrs
            out_attrs['ReplaceSelfWith'] = str(out_node_name)

            json_dict[str(counter)]["NodesInvolved"] = list(nodes_involved_dict.values())
            counter += 1

    with open(json_output_path, 'w') as json_file:
        json.dump(json_dict, json_file, indent=4)
    return G, TopFunctionName

In [11]:
def egraph_flow(ir_input_path, delay_model='sky130', 
               Schedule_Method="SDC", selector="NoRule", clock_period_ps=None, clock_margin_precent=None, 
               pipeline_stages=None, period_relaxation_percent=None,
               command_executable='/home/miao/xls/bazel-bin/xls/tools/RL_main'):
    """Run the e-graph rewrite flow on an XLS IR file and schedule the result.

    Steps: unify node names, generate a rewrite JSON with the chosen selector,
    apply it with the rewriter binary, schedule the rewritten IR, and read the
    schedule back into a graph.

    Args:
        ir_input_path: Path to the input IR. It must end in ``.ir``. All
            intermediate and output files are written next to it, named by
            replacing that trailing ``.ir`` suffix.
        delay_model: Delay model name, forwarded to ``run_sdc_scheduler``.
        Schedule_Method: Currently unused; kept for interface compatibility.
        selector: Forwarded to ``gen_json_from_ir`` as ``rewrite_rule``.
            That function's selector branch recognises "NoRule" and
            "NaivePick". Any other string falls back to the last rewrite
            candidate, the same as "NoRule".
        clock_period_ps, clock_margin_precent, pipeline_stages,
        period_relaxation_percent: Forwarded unchanged to ``run_sdc_scheduler``.
        command_executable: Rewriter binary, passed to ``run_rewriter``.

    Returns:
        ``(G_schedule, max_stage_latency, stage_num, register_count)`` as
        returned by ``get_graph_with_scheduling``.

    Raises:
        ValueError: If ``ir_input_path`` does not end in ``.ir``.
    """
    ir_suffix = '.ir'
    # Derive every output path from the trailing suffix only. A plain
    # str.replace would also rewrite any '.ir' inside directory names. For a
    # path without '.ir' it would return the input path unchanged, so later
    # steps would overwrite the input file.
    if not ir_input_path.endswith(ir_suffix):
        raise ValueError(f"Expected an IR path ending in '{ir_suffix}', got: {ir_input_path!r}")
    base_path = ir_input_path[:-len(ir_suffix)]
    ir_unify_name_out = base_path + '.unify.ir'
    json_output_path = base_path + '.json'
    ir_output_path = base_path + '_substitution.ir'
    schedule_result_path = base_path + '_schedule.txt'
    output_verilog_path = base_path + '.v'
    output_schedule_ir_path = base_path + '_substitution_schedule.ir'

    run_unify_name(ir_input_path, ir_unify_name_out)
    # Get the rewrite JSON file from the IR using the chosen selector.
    _, TopFunctionName = gen_json_from_ir(ir_unify_name_out, json_output_path, rewrite_rule=selector)
    # Apply the rewrites with the rewriter binary.
    run_rewriter(command_executable, ir_unify_name_out, json_output_path, ir_output_path)
    # Schedule the rewritten IR and generate the result files.
    run_sdc_scheduler(ir_output_path, output_verilog_path, schedule_result_path, output_schedule_ir_path,
                      delay_model, TopFunctionName, clock_period_ps, clock_margin_precent,
                      pipeline_stages, period_relaxation_percent)
    # Collect the scheduling result into a graph.
    G_schedule, max_stage_latency, stage_num, register_count = get_graph_with_scheduling(
        output_schedule_ir_path, schedule_result_path)

    return G_schedule, max_stage_latency, stage_num, register_count
    

In [12]:
# gen_json_from_ir("/home/miao/xls/work_space/EGraphUnitTest/all_unit_test.ir", "/home/miao/xls/work_space/EGraphUnitTest/all_unit_test.json")

In [13]:
# Alternative invocations:
#   egraph_flow("/home/miao/xls/work_space/EGraphTest/test.opt.ir", clock_period_ps = 1000)
#   egraph_flow("/home/miao/xls/work_space/Sha256/sha256.opt.ir", clock_period_ps = 3000, selector="NaivePick")
#   egraph_flow("/home/miao/xls/work_space/adler32/adler32.opt.ir", clock_period_ps = 25825)
egraph_flow(
    "/home/miao/xls/work_space/EGraphUnitTest/all_unit_test.ir",
    clock_period_ps=1000,
)

Name Unification Run Done.
FunctionName:  AddAssociativity
Input Expr:  add(add('z1','x1'),'y1')
[add(x='x1', y=add(x='z1', y='y1')), add(x='z1', y=add(x='x1', y='y1'))]
FunctionName:  Unsigned_MulAssociativity
Input Expr:  mul(mul('z2','x2'),'y2')
[mul(x='x2', y=mul(x='z2', y='y2')), mul(x='z2', y=mul(x='x2', y='y2'))]
FunctionName:  MulDistribution
Input Expr:  mul(add('z3','x3'),'y3')
[mul(x=add(x='z3', y='x3'), y='y3')]
FunctionName:  SumSame
Input Expr:  add('x4','x4')
[mul(x='x4', y=2), add(x='x4', y='x4')]
FunctionName:  MulSumSame
Input Expr:  add(mul('x5','y5'),'y5')
[mul(x=add(x='x5', y=1), y='y5'), mul(x='y5', y=add(x=1, y='x5'))]
FunctionName:  MulSumSame
Input Expr:  add(mul('x5','y5'),'z5')
[add(x=mul(x='x5', y='y5'), y='z5'), add(x=mul(x='y5', y='x5'), y='z5')]
FunctionName:  Signed_MulAssociativity
Input Expr:  mul(mul('z6','x6'),'y6')
[mul(x='x6', y=mul(x='z6', y='y6'))]
FunctionName:  SubToNeg
Input Expr:  sub('x7','y7')
[sub(x='x7', y='y7')]
FunctionName:  MulByTwo
I

(<networkx.classes.digraph.DiGraph at 0x7fdcdc3e9510>, 853, 2, 6)

In [14]:
# Compare a baseline schedule (no rewrites) with the e-graph flow using the
# NaivePick selector, at the same clock period.
# NOTE(review): get_sdc_from_graph is defined in a later cell. On a fresh kernel
# this cell raises NameError unless that cell is run first (or this cell is
# moved below it).
# NOTE(review): both runs derive identical output file names from `path`
# (json, schedule txt, verilog, schedule IR), so the second run overwrites the
# baseline's files. G and the metrics variables are also rebound by it.
path = "/home/miao/xls/work_space/Sha256/sha256.opt.ir"
# path = "/home/miao/xls/work_space/MaxTimeTest/test.opt.ir"
G, max_stage_latency, stage_num, register_count = get_sdc_from_graph(path, clock_period_ps=4000)
print("Number Of Nodes Before: ", len(G.nodes()), "Reg count: ", register_count)
print("Max Pipeline Latency Before: ", max_stage_latency, "stages:" , stage_num)
G, max_stage_latency, stage_num, register_count = egraph_flow(path, clock_period_ps = 4000, selector="NaivePick")
print("Number Of Nodes After: ", len(G.nodes()), "Reg count: ", register_count)
print("Max Pipeline Latency After: ", max_stage_latency, "stages:" , stage_num)


NameError: name 'get_sdc_from_graph' is not defined

Rewriter Run Done
Scheduling Run Done.
Reading Scheduling Result
Result Graph Generation Done
Number Of Nodes After:  4929 Reg count:  1235
Max Pipeline Latency After:  3960 stages: 56


In [None]:
# The following cells build a graph for each rewrite expression
# and test ways to use these graphs when selecting a rewrite.

In [None]:
# These are helper functions for selectors
def expr_to_graph(G, expr, bit_width, cur_operation_mapping):
    """Build a DiGraph of the nodes that a rewrite expression would generate.

    The expression is expanded with ``gen_json_from_expr_recursive``. Only
    entries whose ``ReplaceSelfWith`` is ``'Gen'`` (newly generated nodes)
    are kept. Operands that are not among those generated nodes get
    placeholder ``Node`` objects that carry only a name. Every node's
    ``idNum`` is then renumbered 0..N-1 in dictionary order before
    ``DictToGraph`` builds the graph.
    """
    nodes_involved, _out_name = gen_json_from_expr_recursive(G, str(expr), cur_operation_mapping, bit_width)
    generated = {name: info for name, info in nodes_involved.items() if info['ReplaceSelfWith'] == 'Gen'}
    node_objs = dict_to_nodes(generated)

    placeholders = {}
    for node_obj in node_objs.values():
        if node_obj.operands is None:
            # Normalise "no operands" to an empty list.
            node_obj.operands = []
            continue
        for operand in node_obj.operands:
            if operand in node_objs or operand in placeholders:
                continue
            placeholders[operand] = Node(name=operand, operands=[])

    # Placeholders are appended after the generated nodes.
    node_objs.update(placeholders)
    for idx, node_obj in enumerate(node_objs.values()):
        node_obj.idNum = idx

    return DictToGraph(nx.DiGraph(), node_objs)

def dict_to_nodes(node_dict):
    """Convert node dictionaries into ``Node`` objects.

    Each value of ``node_dict`` is read with the same keys that
    ``node_to_dict`` produces ("OperationName", "BitWidth",
    "OperationType", ...). Operation types listed in the table below are
    translated to short names (e.g. "kAdd" -> "add"). Any other operation
    type, including a missing one (None), is kept as-is.

    Returns:
        dict mapping each key of ``node_dict`` to a newly built ``Node``.
    """
    # NOTE(review): "Literal" and "bit_slice" do not follow the "kXxx" naming of
    # the other keys; confirm they match what the JSON generator emits.
    short_op_name = {
        "kAdd": "add",
        "kUMul": "mul",
        "kSMul": "mul",
        "kUDiv": "div",
        "kSDiv": "div",
        "kSub": "sub",
        "Literal": "literal",
        "kNeg": "neg",
        "kShll": "shll",
        "kShrl": "shrl",
        "kNot": "not_",
        "kConcat": "concat",
        "bit_slice": "bitslice",
    }

    def build_node(fields):
        op_type = fields.get("OperationType")
        return Node(
            name=fields.get("OperationName"),
            bitwidth=fields.get("BitWidth"),
            operation=short_op_name.get(op_type, op_type),
            operands=fields.get("Operands"),
            idNum=fields.get("idNum"),
            value=fields.get("Value"),
            pos=fields.get("Pos"),
            FuncIO=fields.get("FuncIO"),
            start=fields.get("Start"),
            width=fields.get("Width"),
            array_sizes=fields.get("ArraySize"),
            indices=fields.get("Indices"),
        )

    return {name: build_node(fields) for name, fields in node_dict.items()}


In [None]:
def get_graph_with_smallest_longest_path(graph_list):
    """Return the index of the graph whose longest path has the fewest edges.

    Each graph is expected to be a networkx DAG. Graphs that are not DAGs are
    reported and skipped. Ties go to the later index.

    Args:
        graph_list: Sequence of networkx directed graphs.

    Returns:
        Index of the selected graph, or -1 if ``graph_list`` is empty or none
        of the graphs is a DAG.
    """
    min_longest_path_length = float('inf')
    index_of_min = -1

    for i, graph in enumerate(graph_list):
        try:
            longest_path = nx.dag_longest_path(graph)
        except nx.NetworkXUnfeasible:
            # This exception is raised if the graph is not a DAG.
            print(f"Graph at index {i} is not a DAG.")
            continue
        # Edge count of the path. An empty graph yields an empty path, so clamp
        # at 0; otherwise it would score -1 and beat a single-node graph.
        longest_path_length = max(len(longest_path) - 1, 0)
        # `<=` hands ties to the later graph, since indices only increase.
        if longest_path_length <= min_longest_path_length:
            min_longest_path_length = longest_path_length
            index_of_min = i

    return index_of_min

def get_graph_with_fewest_nodes(graph_list):
    """Return the index of the graph with the fewest nodes.

    Ties go to the later index. Returns -1 for an empty ``graph_list``.
    """
    best_index, best_size = -1, float('inf')
    for index, graph in enumerate(graph_list):
        size = len(graph.nodes)
        # `<=` makes a later graph with an equal node count win the tie.
        if size <= best_size:
            best_index, best_size = index, size
    return best_index


def naive_pick(G, cur_rewrite_candidate, bit_width, cur_operation_mapping):
    """Pick the rewrite candidate whose generated graph has the fewest nodes.

    Each candidate expression is turned into a graph with ``expr_to_graph``.
    The winner is chosen by ``get_graph_with_fewest_nodes`` (ties go to the
    later candidate). An empty candidate list raises IndexError.
    """
    candidate_graphs = [
        expr_to_graph(G, expr, bit_width, cur_operation_mapping)
        for expr in cur_rewrite_candidate
    ]
    # Alternative criterion: get_graph_with_smallest_longest_path(candidate_graphs)
    chosen_index = get_graph_with_fewest_nodes(candidate_graphs)
    return cur_rewrite_candidate[chosen_index]

In [None]:
# The selector name must match the branch in gen_json_from_ir exactly ("NaivePick").
# Any other string silently falls back to picking the last rewrite candidate.
G, FuncName = gen_json_from_ir("/home/miao/xls/work_space/Sha256/sha256.opt.ir", "/home/miao/xls/work_space/Sha256/sha256.json", rewrite_rule="NaivePick")
# nx.draw(G, with_labels=True, font_weight='bold')
# G, FuncName = gen_json_from_ir("/home/miao/xls/work_space/EGraphUnitTest/all_unit_test.ir", "/home/miao/xls/work_space/EGraphUnitTest/all_unit_test.json", rewrite_rule="NaivePick")

In [None]:
# The following cells compute baseline SDC scheduling results (without rewrites) and test a GNN on the scheduled graph.

In [None]:
def get_sdc_from_graph(ir_input_path, delay_model='sky130', 
               Schedule_Method="SDC", clock_period_ps=None, clock_margin_precent=None, 
               pipeline_stages=None, period_relaxation_percent=None):
    """Schedule an XLS IR file without applying any rewrites (baseline run).

    Steps: unify node names, run ``gen_json_from_ir`` to obtain the top
    function name, schedule the name-unified IR directly, and read the
    schedule back into a graph.

    Args:
        ir_input_path: Path to the input IR. It must end in ``.ir``. Output
            files are written next to it, named by replacing that trailing
            ``.ir`` suffix.
        delay_model: Delay model name, forwarded to ``run_sdc_scheduler``.
        Schedule_Method: Currently unused; kept for interface compatibility.
        clock_period_ps, clock_margin_precent, pipeline_stages,
        period_relaxation_percent: Forwarded unchanged to ``run_sdc_scheduler``.

    Returns:
        ``(G_schedule, max_stage_latency, stage_num, register_count)`` as
        returned by ``get_graph_with_scheduling``.

    Raises:
        ValueError: If ``ir_input_path`` does not end in ``.ir``.

    NOTE(review): the output file names are the same ones egraph_flow derives
    for the same input, so running both on one input overwrites these results.
    """
    ir_suffix = '.ir'
    # Derive paths from the trailing suffix only. A plain str.replace would also
    # rewrite '.ir' inside directory names, or return the input path itself
    # (clobbering it) when there is no '.ir'.
    if not ir_input_path.endswith(ir_suffix):
        raise ValueError(f"Expected an IR path ending in '{ir_suffix}', got: {ir_input_path!r}")
    base_path = ir_input_path[:-len(ir_suffix)]
    ir_unify_name_out = base_path + '.unify.ir'
    json_output_path = base_path + '.json'
    schedule_result_path = base_path + '_schedule.txt'
    output_verilog_path = base_path + '.v'
    # Keeps the "_substitution" naming used by egraph_flow even though no
    # substitution is performed here.
    output_schedule_ir_path = base_path + '_substitution_schedule.ir'

    run_unify_name(ir_input_path, ir_unify_name_out)
    # Needed for TopFunctionName. Note that this also writes json_output_path.
    _, TopFunctionName = gen_json_from_ir(ir_unify_name_out, json_output_path)
    # No rewriter step: schedule the name-unified IR directly.
    run_sdc_scheduler(ir_unify_name_out, output_verilog_path, schedule_result_path, output_schedule_ir_path,
                      delay_model, TopFunctionName, clock_period_ps, clock_margin_precent,
                      pipeline_stages, period_relaxation_percent)
    # Collect the scheduling result into a graph.
    G_schedule, max_stage_latency, stage_num, register_count = get_graph_with_scheduling(
        output_schedule_ir_path, schedule_result_path)

    return G_schedule, max_stage_latency, stage_num, register_count

# Baseline (no-rewrite) schedule for sha256 at a 1000 ps clock period.
G, max_stage_latency, stage_num, register_count = get_sdc_from_graph(
    "/home/miao/xls/work_space/Sha256/sha256.opt.ir",
    clock_period_ps=1000,
)
print("Max Pipeline Latency: ", max_stage_latency)

In [None]:
# One-hot vocabulary of operation types. The trailing None entry keeps the
# feature width unchanged; nodes whose OperationType is missing get an
# all-zero one-hot vector.
# NOTE(review): XLS also has an "or" op; confirm whether "nor" was meant to be "or".
mapping_operations = ["add", "umul", "smul", "literal", "shll", "shrl", "not", 
                      "neg", "sub", "udiv", "sdiv", "bit_slice", "concat", "xor", "and", 
                      "nor", "tuple",None]

# Create a mapping for one-hot encoding
op_mapping = {op: i for i, op in enumerate(mapping_operations)}
num_ops = len(mapping_operations)

# Index the nodes 0..num_nodes-1 in G's iteration order.
node_mapping = {node: i for i, node in enumerate(G.nodes())}
num_nodes = len(node_mapping)

# COO edge index of shape (2, num_edges). reshape(-1, 2) keeps the (2, 0) shape
# when G has no edges; a bare .t() on an empty 1-D tensor would leave shape (0,).
edges = torch.tensor(
    [[node_mapping[u], node_mapping[v]] for u, v in G.edges()], dtype=torch.long
).reshape(-1, 2).t().contiguous()

# Per-node features: [node_delay_ps, path_delay_ps] followed by the one-hot op type.
unknown_op_types = set()
x = []
for node in G.nodes():
    node_attrs = G.nodes[node]
    node_features = [
        node_attrs['node_delay_ps'],
        node_attrs['path_delay_ps']
    ]
    op_type = node_attrs.get('OperationType')
    op_vector = [0] * num_ops  # Initialize with zeros
    if op_type is not None:
        if op_type in op_mapping:
            op_vector[op_mapping[op_type]] = 1  # One-hot: set the matching index
        else:
            # Op types outside the vocabulary are encoded as all zeros (like a
            # missing OperationType) and reported below instead of raising KeyError.
            unknown_op_types.add(op_type)
    node_features.extend(op_vector)
    x.append(node_features)

if unknown_op_types:
    print("Warning: OperationType values not in mapping_operations (encoded as all zeros):",
          sorted(map(str, unknown_op_types)))

# reshape keeps a 2-D (0, 2 + num_ops) feature matrix even if G has no nodes.
x = torch.tensor(x, dtype=torch.float).reshape(-1, 2 + num_ops)

# Create PyTorch Geometric data
data = Data(x=x, edge_index=edges)

# Define the GNN Model
class GNN(torch.nn.Module):
    """Two GCN layers followed by a per-node linear head.

    Args:
        n: Output size of the final linear layer, per node.
        in_channels: Number of input node features. When omitted, the value
            is taken from ``data.num_node_features`` of the module-level
            ``data`` object, as before.
    """

    def __init__(self, n, in_channels=None):
        super(GNN, self).__init__()
        if in_channels is None:
            # Backward-compatible default: size the first layer from the
            # notebook-global `data`.
            in_channels = data.num_node_features
        self.conv1 = GCNConv(in_channels, 16)
        self.conv2 = GCNConv(16, 32)
        self.linear = torch.nn.Linear(32, n)

    def forward(self, data):
        """Apply conv1 -> ReLU -> conv2 -> linear to ``data.x`` over ``data.edge_index``.

        Presumably returns one row of ``n`` values per node (assumes ``data.x``
        is (num_nodes, in_channels); confirm against the caller).
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x, edge_index)
        # NOTE(review): no nonlinearity between conv2 and the linear head; confirm intended.
        x = self.linear(x)
        return x

# Seed the RNG so that the randomly initialised weights, and therefore `out`,
# are reproducible across runs.
torch.manual_seed(0)

# Instantiate the model; `n` is the per-node output size (example value).
n = 10
model = GNN(n)

# Example forward pass
out = model(data)
# Expect one row per graph node with n outputs each.
assert out.shape == (num_nodes, n), f"unexpected output shape {tuple(out.shape)}"
print(out)