# Modularization Based on Minimum Latency

## Overall flow

![general_flow](general_flow.png) 

## Environment

In [None]:
from qonnx.core.modelwrapper import ModelWrapper
import pandas as pd
import numpy as np
import configparser
import json

from prepare_model import prepareModel
from com_latency_estimate import com_latency
from generate_report import generateReport
from cut import generate_block
from fussion_parallelize import fussionParallelize

# Working directory: passed to prepareModel and fussionParallelize below, and
# later used as the folder where the final synthesized .onnx modules are saved.
build_dir = "../../../../notebooks/GitHub/M_Project/Patitioning/tmp" 

## Model and Parallelization

In [None]:
import os  # os.path.basename is used below; the environment cell does not import os

model_name = "cnv_e50_1bit_trained"
model_dir = "../../../../notebooks/GitHub/M_Project/Model/tmp"

# Prepare the named model; the result is treated as a file path below.
model_finn_path = prepareModel(model_name, build_dir, model_dir)

# Optimization and folding configuration. Only the file name of the prepared
# model is passed on, together with the build directory.
model_folded_path = fussionParallelize(build_dir, os.path.basename(model_finn_path))

## Partitioning Configuration

### Constraints

In [None]:
# Device profiles live in an INI file: a [GLOBAL] section with shared settings
# plus one [Device_<n>] section per device.
config = configparser.ConfigParser()
config.read('device_profiles.ini')

maxNumberOfDevices = int(config["GLOBAL"]["maxNumberOfDevices"])
frequency = int(config["GLOBAL"]["frequency"])

# Resource limits of every device, in section order Device_0, Device_1, ...
devices = []
for d in range(maxNumberOfDevices):
    profile = config[f"Device_{d}"]
    device = {
        "maxLUT": int(profile["maxLUT"]),
        "maxBRAM": int(profile["maxBRAM"]),
    }
    # "maxCommunication" is intentionally not read from the profile here.
    devices.append(device)

# Data type name -> number (presumably the bit width of the type).
lib = {"INT8": 8, "BINARY": 1, "UINT32": 32}

### Design Space

In [None]:
def estimate_communication_cost(output_shapes, frequency=1330000000000):
    """Return ``com_latency`` for ``output_shapes`` multiplied by ``10**9``.

    The scaling is meant to convert the estimate to nanoseconds (this assumes
    ``com_latency`` reports seconds -- TODO confirm).

    Args:
        output_shapes: forwarded unchanged to ``com_latency``.
        frequency: forwarded to ``com_latency``. The default was annotated by
            the author as the maximum board frequency (presumably of the
            KV260; value and unit to be confirmed).

    Returns:
        The scaled communication cost.
    """
    nano_scale = 10**9
    # The batch size is fixed at 16 for every estimate.
    cost = com_latency(output_shapes, frequency, batchsize=16)
    cost *= nano_scale
    return cost

def sub_block_resource_precompute(number_of_nodes, node_lut, node_bram, node_latency):
    """Tabulate latency, LUT and BRAM totals for every contiguous node range.

    Entry ``[start][end]`` of each returned table is the sum of the per-node
    values over nodes ``start..end`` (inclusive). Entries with ``end < start``
    stay ``0``.

    Args:
        number_of_nodes: number of nodes; also the side length of each table.
        node_lut: per-node LUT values.
        node_bram: per-node BRAM values.
        node_latency: per-node latency values.

    Returns:
        Tuple ``(execution_latency, lut_usage, bram_usage)`` of square tables.
    """

    def range_totals(per_node):
        # One running total per start node, extended one end node at a time.
        table = [[0] * number_of_nodes for _ in range(number_of_nodes)]
        for first in range(number_of_nodes):
            running = 0
            for last in range(first, number_of_nodes):
                running += per_node[last]
                table[first][last] = running
        return table

    return range_totals(node_latency), range_totals(node_lut), range_totals(node_bram)

def get_nodes_information(number_of_nodes, model_folded_path, target_clk_ns):
    """Collect latency, resource and output information for every single node.

    Each node is cut out as its own sub-model with ``generate_block`` and then
    estimated with ``generateReport``.

    Args:
        number_of_nodes: how many nodes (indices ``0..number_of_nodes-1``) to
            inspect.
        model_folded_path: path of the folded model handed to
            ``generate_block``.
        target_clk_ns: clock period forwarded to ``generateReport``.

    Returns:
        Tuple of six lists, one entry per node:
        ``(node_latency, node_lut, node_bram, output_shapes, output_sizes,
        output_dtypes)``. ``output_sizes`` holds the output *shape* of each
        node (the same objects as ``output_shapes``), not an element count.
        ``output_dtypes`` holds ``str(output_dtype)``.
    """
    node_latency = []
    node_lut = []
    node_bram = []
    output_shapes = []
    output_sizes = []
    output_dtypes = []

    for node_index in range(number_of_nodes):
        # Single-node block: start at node_index, end bound node_index + 1.
        # The input shape/dtype reported for the block are not needed here.
        path_submodel, output_shape, output_dtype, _input_shape, _input_dtype = generate_block(
            model_folded_path, node_index, node_index + 1
        )
        layer_resources, network_performance = generateReport(
            ModelWrapper(path_submodel), target_clk_ns
        )

        node_latency.append(network_performance["estimated_latency_ns"])
        node_lut.append(layer_resources["total"]["LUT"])
        node_bram.append(layer_resources["total"]["BRAM_18K"])
        output_shapes.append(output_shape)
        # Kept as the shape on purpose; the element count used previously was
        # dead code and is no longer computed.
        output_sizes.append(output_shape)
        output_dtypes.append(str(output_dtype))

    return node_latency, node_lut, node_bram, output_shapes, output_sizes, output_dtypes
    
def generate_partitioning_configuration(num_devices, lut_usage, bram_usage, max_lut, max_bram, 
                    output_shapes, execution_latency, communication_latency_fn):
    """Find the lowest-latency split of a node chain into consecutive blocks.

    The nodes ``0..N-1`` are cut into consecutive, non-empty blocks. The b-th
    block (0-based) is placed on device ``b`` and must respect that device's
    LUT and BRAM limits. The latency of a partitioning is the sum of all node
    latencies plus, for every block that does not start at node 0, the
    communication latency of the output of the node just before the block.

    Every device count from 1 to ``num_devices`` (inclusive) is considered.
    On equal latency the smaller device count wins.

    Args:
        num_devices: maximum number of devices that may be used.
        lut_usage: per-node LUT usage (1-D sequence).
        bram_usage: per-node BRAM usage (1-D sequence).
        max_lut: per-device LUT limit, indexed by device.
        max_bram: per-device BRAM limit, indexed by device.
        output_shapes: per-node output shape, passed to
            ``communication_latency_fn``.
        execution_latency: per-node execution latency (1-D sequence).
        communication_latency_fn: callable mapping one output shape to a
            communication latency.

    Returns:
        Tuple ``(best_devices, best_split, best_latency)``. ``best_split`` is
        a list with one list of node indices per device. When no feasible
        partitioning exists the result is ``(0, None, float('inf'))``.
    """
    num_nodes = len(lut_usage)
    num_devices = max(num_devices, 0)
    unreachable = float('inf')

    # Latency contributed by the block of nodes [start, end).
    def compute_latency(start, end):
        exec_latency = sum(execution_latency[start:end])
        if start > 0:  # Communication cost from the previous block
            comm_latency = communication_latency_fn(output_shapes[start - 1])
        else:
            comm_latency = 0
        return comm_latency + exec_latency

    # Resource check for the block of nodes [start, end) on one device.
    def is_valid_block(start, end, device_idx):
        total_lut = sum(lut_usage[start:end])
        total_bram = sum(bram_usage[start:end])
        return total_lut <= max_lut[device_idx] and total_bram <= max_bram[device_idx]

    # dp[i][j]: minimum latency for placing the first i nodes on exactly j devices.
    # split[i][j]: first node of the last block in that optimum.
    # A single table covers every device count, because column j never
    # depends on columns beyond j.
    dp = [[unreachable] * (num_devices + 1) for _ in range(num_nodes + 1)]
    split = [[None] * (num_devices + 1) for _ in range(num_nodes + 1)]
    dp[0][0] = 0  # Base case: 0 latency with 0 nodes and 0 devices

    for i in range(1, num_nodes + 1):  # Number of nodes covered so far
        for j in range(1, num_devices + 1):  # Number of devices used so far
            for k in range(i):  # Last block is nodes [k, i) on device j - 1
                if dp[k][j - 1] == unreachable:
                    continue  # The first k nodes cannot be placed on j - 1 devices
                if not is_valid_block(k, i, j - 1):
                    continue
                latency = dp[k][j - 1] + compute_latency(k, i)
                if latency < dp[i][j]:
                    dp[i][j] = latency
                    split[i][j] = k

    best_latency = unreachable
    best_split = None
    best_devices = 0

    # Strict comparison keeps the smallest device count among equal latencies.
    for d in range(1, num_devices + 1):
        if dp[num_nodes][d] < best_latency:
            best_latency = dp[num_nodes][d]
            best_devices = d

            # Walk the split table backwards from the full model.
            best_split = []
            current_node = num_nodes
            for j in range(d, 0, -1):
                prev_node = split[current_node][j]
                best_split.append(list(range(prev_node, current_node)))
                current_node = prev_node
            best_split.reverse()  # Blocks were collected last-to-first

    return best_devices, best_split, best_latency

In [None]:
# Load the folded model and gather the per-device resource limits.
model = ModelWrapper(model_folded_path)
number_of_nodes = len(model.graph.node)
num_devices = len(devices)-1  # not passed on below; the search receives maxNumberOfDevices
max_lut = [d["maxLUT"] for d in devices]
max_bram = [d["maxBRAM"] for d in devices]

# Per-node estimates for the whole model at a 10 ns target clock.
node_info = get_nodes_information(number_of_nodes, model_folded_path, target_clk_ns=10)
node_latency, node_lut, node_bram, output_shapes, output_sizes, output_dtypes = node_info

# Search for the partitioning with the lowest total latency.
partitioning = generate_partitioning_configuration(
    num_devices=maxNumberOfDevices,
    lut_usage=node_lut,
    bram_usage=node_bram,
    max_lut=max_lut,
    max_bram=max_bram,
    output_shapes=output_shapes,
    execution_latency=node_latency,
    communication_latency_fn=estimate_communication_cost,
)
best_devices, best_split, best_latency = partitioning

In [None]:
# One line per module: its index plus the first and last node it contains.
for module_index, node_ids in enumerate(best_split):
    first_node, last_node = node_ids[0], node_ids[-1]
    print(f"module: {module_index:<3} start: {first_node:<4} end: {last_node:<4}")

In [None]:
import math

# Table of every node's output shape and the number of elements it holds.
print("output_shapes")
for index, output_shape in enumerate(output_shapes):
    total_elements = math.prod(output_shape)
    row = f"{index:<5} {str(output_shape):<20} {total_elements:<10}"
    print(row)

In [None]:
# Per-node BRAM usage next to the running total up to and including that node.
print("node_bram")
total_elements = 0
for index, nl in enumerate(node_bram):
    total_elements = total_elements + nl
    row = f"{index:<5} {str(nl):<20} {total_elements:<10}"
    print(row)

## Partitioning

### generate sub model

In [None]:
# Reduce each block's node list to an inclusive (first node, last node) pair.
optimal_assignment = [(blck[0], blck[-1]) for blck in best_split]
optimal_assignment

In [None]:
import onnx
from qonnx.transformation.general import GiveUniqueNodeNames
from qonnx.transformation.general import RemoveUnusedTensors

# Paths of the generated sub-models, one per block of optimal_assignment,
# in block order.
sub_modules_path = []

for block_cfg in optimal_assignment:
    # block_cfg is (first node, last node); the end argument is passed as last + 1.
    path_submodel, output_shape, output_dtype, input_shape, input_dtype = generate_block(model_folded_path, block_cfg[0], block_cfg[1]+1)
    # Set the shape of the output tensor (addressed by the name "global_out")
    model = ModelWrapper(path_submodel)
    model = model.transform(RemoveUnusedTensors())
    model = model.transform(GiveUniqueNodeNames())
    model.set_tensor_shape("global_out", output_shape, dtype=onnx.TensorProto.FLOAT)
    # Set the shape of the input tensor. Its name is looked up in a fresh copy
    # of the sub-model loaded from disk (first input of the first node), not in
    # the already transformed `model`.
    model.set_tensor_shape(ModelWrapper(path_submodel).graph.node[0].input[0], input_shape, dtype=onnx.TensorProto.FLOAT)
    model = model.transform(RemoveUnusedTensors())
    model = model.transform(GiveUniqueNodeNames())
    
    # Overwrite the sub-model file in place with the adjusted model.
    model.save(path_submodel)
    sub_modules_path.append(path_submodel)
sub_modules_path;

## Synthesize

In [None]:
from finn.transformation.fpgadataflow.make_zynq_proj import ZynqBuild
from finn.transformation.fpgadataflow.make_pynq_driver import MakePYNQDriver
from qonnx.transformation.general import GiveUniqueNodeNames
from qonnx.transformation.general import RemoveUnusedTensors
from finn.util.basic import make_build_dir
from shutil import copy
from distutils.dir_util import copy_tree
import uuid

# Build settings shared by every sub-module.
test_pynq_board = "KV260_SOM"
target_clk_ns = 10
# Filled in block order: deployment folder and saved .onnx path of each module.
deployment_directories = []
final_modules_directories = []

for block in sub_modules_path:
    # Clean up, then run the Zynq build and generate the PYNQ driver.
    module = ModelWrapper(block)
    module = module.transform(RemoveUnusedTensors())
    module = module.transform(GiveUniqueNodeNames())
    module = module.transform(ZynqBuild(platform = test_pynq_board, period_ns = target_clk_ns))
    module = module.transform(MakePYNQDriver("zynq-iodma"))

    # Fresh deployment folder for this module, recorded in the model metadata.
    deployment_dir = make_build_dir(prefix="pynq_deployment_")
    module.set_metadata_prop("pynq_deployment_dir", deployment_dir)
    deployment_directories.append(deployment_dir)

    # Copy the bitfile and hardware handoff file, skipping any that is not set.
    bitfile = module.get_metadata_prop("bitfile")
    hwh_file = module.get_metadata_prop("hw_handoff")
    deploy_files = [bitfile, hwh_file]
    for dfile in deploy_files:
        if dfile is not None:
            copy(dfile, deployment_dir)
    # Copy the generated driver folder into the deployment folder.
    # NOTE(review): copy_tree comes from distutils, which is deprecated
    # (PEP 632) and removed in Python 3.12 -- confirm the target Python version.
    pynq_driver_dir = module.get_metadata_prop("pynq_driver_dir")
    copy_tree(pynq_driver_dir, deployment_dir)

    # Save the built module under a random (UUID4) file name inside build_dir.
    final_modules_directory = build_dir + f"/{str(uuid.uuid4())}.onnx"
    module.save(final_modules_directory)
    final_modules_directories.append(final_modules_directory)

## Report

In [None]:
# Record where the deployment folders and the final .onnx modules were written.
# Each list is stored as its str() representation inside a two-element JSON array.
final_report = [str(deployment_directories), str(final_modules_directories)]
with open('report.txt', 'w') as report_file:
    report_file.write(json.dumps(final_report, indent=4))

In [None]:
import json

# Same report as above with one extra leading entry.
# NOTE(review): sub_modules_report is not assigned in any cell visible here;
# confirm where it is defined before running this cell.
final_report = [
    str(sub_modules_report),
    str(deployment_directories),
    str(final_modules_directories),
]
with open('example_report_summary.txt', 'w') as summary_file:
    json.dump(final_report, summary_file, indent=4)
print("example_report_summary.txt")

#### option 1

In [None]:
import matplotlib.pyplot as plt
import itertools

def plot_design_space_with_combinations(dp, split_point, number_of_nodes, num_devices):
    """Scatter-plot every finite DP entry as (device count, latency).

    For each device count ``j`` in ``1..num_devices`` and each prefix length
    ``i`` in ``1..number_of_nodes`` with a finite ``dp[i][j]``, the layer
    assignment is rebuilt from ``split_point`` and used as the legend label.
    The device count with the lowest ``dp[number_of_nodes][j]`` is annotated
    as optimal.

    Args:
        dp: table indexed ``dp[i][j]`` holding latencies.
        split_point: table indexed ``split_point[i][j]`` giving the first
            layer of the last block.
        number_of_nodes: number of layers in the full model.
        num_devices: largest device count to plot.
    """
    device_counts = list(range(1, num_devices + 1))

    def trace_assignment(last_layer, used_devices):
        # Follow split_point backwards, collecting inclusive (first, last) blocks.
        blocks = []
        while used_devices > 0 and last_layer > 0:
            first_layer = split_point[last_layer][used_devices]
            blocks.append((first_layer, last_layer - 1))
            last_layer = first_layer
            used_devices -= 1
        blocks.reverse()  # Blocks were collected last-to-first
        return blocks

    # (device count, [(latency, assignment), ...]) for every device count.
    all_configurations = []
    for j in device_counts:
        config_latencies = [
            (dp[i][j], trace_assignment(i, j))
            for i in range(1, number_of_nodes + 1)
            if dp[i][j] != float('inf')  # Only plot valid configurations
        ]
        all_configurations.append((j, config_latencies))

    plt.figure(figsize=(12, 8))

    # One scatter point (and one legend entry) per configuration.
    for j, configs in all_configurations:
        for latency, assignment in configs:
            plt.scatter(j, latency, label=f'Devices: {j}, Layers: {assignment}')

    # Best full-model latency and the device count that reaches it.
    optimal_latency = min(dp[number_of_nodes][j] for j in device_counts)
    optimal_devices = min(device_counts, key=lambda j: dp[number_of_nodes][j])

    plt.xlabel('Number of Devices')
    plt.ylabel('Total Latency')
    plt.title('Design Space: Latency for Layer Assignments Across Devices')
    plt.legend(loc='upper right', bbox_to_anchor=(1.4, 1), fontsize='small', title="Assignments")
    plt.grid(True)

    # Arrow pointing at the optimum, with the text slightly up and to the right.
    label_position = (optimal_devices + 0.3, optimal_latency + 0.05 * optimal_latency)
    plt.annotate(f'Optimal: {optimal_devices} devices\nLatency: {optimal_latency}',
                 xy=(optimal_devices, optimal_latency),
                 xytext=label_position,
                 arrowprops=dict(facecolor='black', shrink=0.05))

    plt.show()

In [None]:
# Plot the design space for 19 nodes and up to 3 devices.
# NOTE(review): dp and split_point must exist at notebook scope for this call;
# no visible cell assigns them outside a function -- confirm.
plot_design_space_with_combinations(dp, split_point, number_of_nodes=19, num_devices=3)

#### option 2

In [None]:
import matplotlib.pyplot as plt

def plot_feasible_design_space(dp, split_point, number_of_nodes, num_devices, max_lut, max_bram, lut_usage, bram_usage):
    """Plot only full-model configurations whose blocks satisfy the limits.

    For every device count ``j`` in ``1..num_devices`` with a finite
    ``dp[number_of_nodes][j]``, the assignment is rebuilt from ``split_point``.
    It is kept only when every block passes the LUT/BRAM check and the walk
    ends at layer 0. The kept configuration with the lowest latency is
    annotated as optimal.

    Args:
        dp: table indexed ``dp[i][j]`` holding latencies.
        split_point: table indexed ``split_point[i][j]``; ``-1`` marks
            "no split recorded".
        number_of_nodes: number of layers in the full model.
        num_devices: largest device count to consider.
        max_lut: single LUT limit compared against every block.
        max_bram: single BRAM limit compared against every block.
        lut_usage: table indexed ``[start][end]`` of block LUT totals.
        bram_usage: table indexed ``[start][end]`` of block BRAM totals.

    Raises:
        ValueError: from ``min`` when no configuration is feasible.
    """
    device_counts = list(range(1, num_devices + 1))
    # Entries are (device count, latency, assignment).
    feasible_configurations = []

    for j in device_counts:
        # Only consider device counts that can process all nodes.
        if not dp[number_of_nodes][j] < float('inf'):
            continue

        assignment = []
        layer = number_of_nodes
        devices_left = j
        valid_assignment = True

        while devices_left > 0 and layer > 0:
            start_layer = split_point[layer][devices_left]
            fits = (start_layer != -1
                    and lut_usage[start_layer][layer - 1] <= max_lut
                    and bram_usage[start_layer][layer - 1] <= max_bram)
            if not fits:
                valid_assignment = False
                break  # Invalid configuration, skip
            assignment.append((start_layer, layer - 1))
            layer = start_layer
            devices_left -= 1

        # Keep only assignments that are valid and reach back to layer 0.
        if valid_assignment and layer == 0:
            assignment.reverse()  # Show the assignment starting from layer 0
            feasible_configurations.append((j, dp[number_of_nodes][j], assignment))

    plt.figure(figsize=(12, 8))

    # One scatter point (and one legend entry) per feasible configuration.
    for device_count, latency, assignment in feasible_configurations:
        plt.scatter(device_count, latency, label=f'Devices: {device_count}, Layers: {assignment}')

    # Feasible configuration with the lowest latency.
    optimal_devices, optimal_latency, optimal_assignment = min(
        feasible_configurations, key=lambda entry: entry[1]
    )

    plt.xlabel('Number of Devices')
    plt.ylabel('Total Latency')
    plt.title('Feasible Design Space: Latency vs Number of Devices with Valid Layer Assignments')
    plt.legend(loc='upper right', bbox_to_anchor=(1.2, 1), fontsize='small', title="Assignments")
    plt.grid(True)

    # Arrow pointing at the optimum, with the text slightly up and to the right.
    label_position = (optimal_devices + 0.3, optimal_latency + 0.05 * optimal_latency)
    plt.annotate(f'Optimal: {optimal_devices} devices\nLatency: {optimal_latency}',
                 xy=(optimal_devices, optimal_latency),
                 xytext=label_position,
                 arrowprops=dict(facecolor='black', shrink=0.05))

    plt.show()


In [None]:
# Plot the feasible design space for 19 nodes, up to 3 devices and a limit of
# 24000 for both LUT and BRAM.
# NOTE(review): dp, split_point, lut_usage and bram_usage must exist at
# notebook scope for this call; no visible cell assigns them outside a
# function -- confirm.
plot_feasible_design_space(
    dp,
    split_point,
    number_of_nodes=19,
    num_devices=3,
    max_lut=24000,
    max_bram=24000,
    lut_usage=lut_usage,
    bram_usage=bram_usage,
)

#### option 3

In [None]:
import unittest
from unittest.mock import MagicMock

class TestOptimalLatency(unittest.TestCase):
    """Unit tests for the latency-driven partitioning helpers.

    NOTE(review): several tests call ``prepare_state_space``,
    ``dynamic_programming``, ``reconstruct_layers`` and
    ``optimal_configuration``. None of these is defined in the part of the
    notebook visible here -- confirm they exist before running the suite.
    """

    def setUp(self):
        """Create a small 5-node / 3-device fixture shared by all tests."""
        # Mock input data for testing
        self.number_of_nodes = 5
        self.num_devices = 3
        self.target_clk_ns = 10

        # Mock device constraints
        self.devices = [
            {"maxLUT": 1000, "maxBRAM": 500},
            {"maxLUT": 2000, "maxBRAM": 1000},
            {"maxLUT": 3000, "maxBRAM": 1500},
        ]

        # Mock node information
        self.node_latency = [100, 200, 150, 300, 250]
        self.node_lut = [500, 400, 300, 700, 600]
        self.node_bram = [200, 150, 100, 250, 200]
        self.output_sizes = [500, 400, 300, 700, 600]

        # Mock model_folded_path, just for placeholder
        self.model_folded_path = "mock_model_path"
        self.com_frequency = 1e8  # Example communication frequency

    def test_estimate_communication_cost(self):
        """Check the communication cost for a 10x20x30 float32 output.

        NOTE(review): this passes three positional arguments and unpacks two
        results, whereas the ``estimate_communication_cost`` defined earlier in
        this notebook accepts ``(output_shapes, frequency)`` and returns a
        single value -- the test appears to target an older signature.
        """
        output_shapes = [10, 20, 30]
        output_dtypes = "float32"
        frequency = 1e8  # Example frequency

        # Assume 32 bits per float and 6000 elements
        expected_cost = (1 / frequency) * 32 * 6000 * 1e9
        cost, elements = estimate_communication_cost(output_shapes, output_dtypes, frequency)
        
        self.assertAlmostEqual(cost, expected_cost)
        self.assertEqual(elements, 6000)

    def test_get_nodes_information(self):
        """Check that one entry per node is returned for each resource list.

        NOTE(review): the mocks below are bound to local names only; they do
        not replace the module-level ``ModelWrapper``, ``generate_block`` and
        ``generateReport`` that ``get_nodes_information`` actually calls.
        """
        # Mocking the required external calls within get_nodes_information
        ModelWrapper = MagicMock()
        ModelWrapper.return_value.graph.node = [None] * self.number_of_nodes

        generate_block = MagicMock()
        generate_block.return_value = ("mock_path", [10, 20], "float32", [10], "float32")

        generateReport = MagicMock()
        generateReport.return_value = (
            {"total": {"LUT": 300, "BRAM_18K": 100}},
            {"estimated_latency_ns": 150}
        )

        # Run the function
        latencies, luts, brams, shapes, sizes, dtypes = get_nodes_information(
            self.number_of_nodes, self.model_folded_path, self.target_clk_ns
        )

        # Assert values are returned correctly
        self.assertEqual(len(latencies), self.number_of_nodes)
        self.assertEqual(len(luts), self.number_of_nodes)
        self.assertEqual(len(brams), self.number_of_nodes)

    def test_prepare_state_space(self):
        """Check the dimensions and the base case of the DP tables."""
        dp, split_point = prepare_state_space(self.num_devices, self.number_of_nodes)
        
        # Check dp and split_point structure
        self.assertEqual(len(dp), self.number_of_nodes + 1)
        self.assertEqual(len(dp[0]), self.num_devices + 1)
        self.assertEqual(dp[0][1], 0)  # Initial state set to zero

    def test_dynamic_programming(self):
        """Check that the DP fills finite latencies for the full model."""
        dp, split_point = prepare_state_space(self.num_devices, self.number_of_nodes)
        execution_latency, lut_usage, bram_usage = sub_block_resource_precompute(
            self.number_of_nodes, self.node_lut, self.node_bram, self.node_latency
        )

        dynamic_programming(
            self.num_devices,
            self.number_of_nodes,
            lut_usage,
            bram_usage,
            self.devices,
            self.output_sizes,
            dp,
            execution_latency,
            split_point
        )

        # Verify that the dp table has been updated with finite values
        self.assertNotEqual(dp[self.number_of_nodes][1], float('inf'))
        self.assertNotEqual(dp[self.number_of_nodes][2], float('inf'))

    def test_reconstruct_layers(self):
        """Check that the reconstructed assignment covers every layer."""
        dp, split_point = prepare_state_space(self.num_devices, self.number_of_nodes)
        execution_latency, lut_usage, bram_usage = sub_block_resource_precompute(
            self.number_of_nodes, self.node_lut, self.node_bram, self.node_latency
        )

        # Run dynamic programming to fill dp and split_point tables
        dynamic_programming(
            self.num_devices,
            self.number_of_nodes,
            lut_usage,
            bram_usage,
            self.devices,
            self.output_sizes,
            dp,
            execution_latency,
            split_point
        )

        # Assume optimal_devices is set after finding minimum latency
        optimal_devices = 2
        assignments, devices_used = reconstruct_layers(self.number_of_nodes, optimal_devices, split_point)

        # Check that assignments cover all layers
        covered_layers = set()
        for start, end in assignments:
            for layer in range(start, end + 1):
                covered_layers.add(layer)

        self.assertEqual(len(covered_layers), self.number_of_nodes)

    def test_optimal_configuration(self):
        """Check coverage, device count and sign of the end-to-end result.

        NOTE(review): ``self.model_folded_path`` is the placeholder string
        "mock_model_path" and nothing is patched here, so this presumably
        needs a real model file or mocking to run -- confirm.
        """
        min_latency, optimal_devices, optimal_assignment, dp, split_point, lut_usage, bram_usage = optimal_configuration(
            self.model_folded_path,
            self.devices,
            self.com_frequency,
            self.target_clk_ns
        )

        # Verify that all nodes are covered in the assignment
        covered_layers = set()
        for start, end in optimal_assignment:
            for layer in range(start, end + 1):
                covered_layers.add(layer)

        self.assertEqual(len(covered_layers), self.number_of_nodes)
        self.assertLessEqual(optimal_devices, self.num_devices)
        self.assertTrue(min_latency >= 0)  # Latency should be non-negative

# Run all tests.
# Inside a notebook, sys.argv holds the kernel's own arguments, which
# unittest.main() would try to interpret as test names, and the default
# exit=True would call sys.exit() when the run finishes. Pass a dummy argv
# (only the program name) and keep the interpreter alive.
unittest.main(argv=["ignored-program-name"], exit=False)

In [None]:
# Hard-coded per-node estimates for a 19-node model, in node order.
# Latency of each node.
node_latency = [
    30720, 81960, 324000, 142920, 282240,
    9800, 26760, 207360, 37440, 288000,
    1250, 3840, 207360, 1440, 184320,
    327680, 327680, 5120, 20,
]
# LUT usage of each node.
node_lut = [
    16.0, 396.0, 2746.0, 428.0, 7410.0,
    0, 428.0, 3855.0, 428.0, 3872.0,
    0, 428.0, 1193.0, 428.0, 524.0,
    336.0, 363.0, 335.0, 0,
]
# BRAM usage of each node.
node_bram = [
    0.0, 0.0, 2.0, 0.0, 29.0,
    0.0, 0.0, 15.0, 0.0, 15.0,
    0.0, 0.0, 24.0, 0.0, 36.0,
    8.0, 16.0, 1.0, 0,
]

In [None]:
def optimal_latency(number_of_nodes, frequency, max_lut, max_bram, num_devices):
    """Search for the minimum-latency split of the node chain using DP.

    The per-node data is read from the module-level lists ``node_lut``,
    ``node_bram`` and ``node_latency``; it is not passed as arguments.

    NOTE(review): a later cell defines another ``optimal_latency`` with a
    different signature, which replaces this one once that cell has run.

    Args:
        number_of_nodes: number of nodes to partition.
        frequency: not used by this implementation.
        max_lut: single LUT limit applied to every block.
        max_bram: single BRAM limit applied to every block.
        num_devices: number of columns (minus one) of the DP tables; because of
            the shifted column scheme described below, at most
            ``num_devices - 1`` blocks can be formed.

    Returns:
        dict with the keys ``"min_latency"``, ``"optimal_devices"`` and
        ``"assignments"`` (list of inclusive ``(start, end)`` node pairs).
    """
    # Precompute execution times and resource usage for each sub-block of layers
    execution_latency = [[0] * number_of_nodes for _ in range(number_of_nodes)]
    lut_usage = [[0] * number_of_nodes for _ in range(number_of_nodes)]
    bram_usage = [[0] * number_of_nodes for _ in range(number_of_nodes)] 

    for start in range(number_of_nodes):
        lut_sum, bram_sum, exec_sum = 0, 0, 0
        for end in range(start, number_of_nodes):            
            lut_sum += node_lut[end]
            bram_sum += node_bram[end]
            exec_sum += node_latency[end]

            execution_latency[start][end] = exec_sum
            lut_usage[start][end] = lut_sum
            bram_usage[start][end] = bram_sum

    # DP table to store the minimum latency up to each layer with a given number of devices
    dp = [[float('inf')] * (num_devices + 1) for _ in range(number_of_nodes + 1)]
    split_point = [[-1] * (num_devices + 1) for _ in range(number_of_nodes + 1)]
    

    # Base case: zero nodes cost nothing. It is stored in column 1 (not 0), so
    # column j of the tables corresponds to j - 1 blocks/devices actually used.
    dp[0][1] = 0

    # Dynamic programming to find minimum latency configuration
    for j in range(1, num_devices + 1):  # Device count
        for i in range(1, number_of_nodes + 1):  # End layer for this device
            for k in range(i):  # Split point
                # Resource constraints check
                if lut_usage[k][i - 1] <= max_lut and bram_usage[k][i - 1] <= max_bram:
                    # Calculate latency for the current split configuration
                    # NOTE(review): node_output_size is not defined in the
                    # visible part of the notebook -- confirm it exists.
                    comm_latency = node_output_size[k - 1] * 0.01 if k > 0 else 0
                    current_latency = dp[k][j - 1] + execution_latency[k][i - 1] + comm_latency

                    # Update dp table if current latency is lower
                    if current_latency < dp[i][j]:
                        dp[i][j] = current_latency
                        split_point[i][j] = k




    # Find the minimum latency for the full model on the available devices
    min_latency = min(dp[number_of_nodes][j] for j in range(1, num_devices + 1))
    # Column index of the minimum; the first minimal column wins on ties.
    optimal_devices = min(range(0, num_devices + 1), key=lambda j: dp[number_of_nodes][j])

    # Reconstruct the optimal device assignment from split_point table
    def reconstruct_layers():
        assignments = []
        i, j = number_of_nodes, optimal_devices
        # Stops at column 1, the base-case column that holds no block.
        while j > 1:
            start_layer = split_point[i][j] #+ 1
            assignments.append((start_layer, i - 1))
            i, j = start_layer, j - 1
        return assignments[::-1]

    # Output results
    optimal_assignment = reconstruct_layers()
    # Convert the column index into the number of devices actually used.
    optimal_devices -= 1
    return {
        "min_latency": min_latency,
        "optimal_devices": optimal_devices,
        "assignments": optimal_assignment
    }

In [None]:
def optimal_latency(model_folded_path, devices, com_frequency, target_clk_ns=10):
    """Estimate every node of the model, then search for the best split via DP.

    Args:
        model_folded_path: path of the folded model to partition.
        devices: list of dicts with the keys ``"maxLUT"`` and ``"maxBRAM"``.
        com_frequency: not used by this implementation.
        target_clk_ns: clock period forwarded to ``generateReport``.

    Returns:
        dict with the keys ``"min_latency"``, ``"optimal_devices"`` and
        ``"assignments"`` (list of inclusive ``(start, end)`` node pairs).
    """
    model = ModelWrapper(model_folded_path)
    number_of_nodes = len(model.graph.node)
    num_devices = len(devices)
#################################################### gathering information of each node. report from finn
    node_latency = []  
    node_lut = []
    node_bram = []
    output_shapes = []
    output_sizes = []
    output_dtypes = []

    for n in range(number_of_nodes):
        #TODO: generate submodules
        start_node = n
        end_node = start_node + 1
        # NOTE(review): get_one_layer is not defined in the visible part of the
        # notebook; it is expected to return three values -- confirm.
        path_submodel, output_shape, output_dtype = get_one_layer(model_folded_path, start_node, end_node)
        estimate_layer_resources, estimate_network_performance = generateReport(ModelWrapper(path_submodel), target_clk_ns)
        # Number of elements of the node output (product of its dimensions).
        output_element = 1
        for size in output_shape:
            output_element = output_element*size
        node_latency.append(estimate_network_performance["estimated_latency_ns"])
        node_lut.append(estimate_layer_resources["total"]["LUT"])
        node_bram.append(estimate_layer_resources["total"]["BRAM_18K"])
        output_shapes.append(output_shape)
        output_sizes.append(output_element)
        output_dtypes.append(str(output_dtype))
    print(node_lut)
    
#######################################################

    # Precompute execution times and resource usage for each sub-block of layers
    execution_latency = [[0] * number_of_nodes for _ in range(number_of_nodes)]
    lut_usage = [[0] * number_of_nodes for _ in range(number_of_nodes)]
    bram_usage = [[0] * number_of_nodes for _ in range(number_of_nodes)] 

    for start in range(number_of_nodes):
        lut_sum, bram_sum, exec_sum = 0, 0, 0
        for end in range(start, number_of_nodes):            
            lut_sum += node_lut[end]
            bram_sum += node_bram[end]
            exec_sum += node_latency[end]

            execution_latency[start][end] = exec_sum
            lut_usage[start][end] = lut_sum
            bram_usage[start][end] = bram_sum

    # DP table to store the minimum latency up to each layer with a given number of devices
    dp = [[float('inf')] * (num_devices + 1) for _ in range(number_of_nodes + 1)]
    split_point = [[-1] * (num_devices + 1) for _ in range(number_of_nodes + 1)]
    
    # Base case: zero nodes cost nothing. It is stored in column 1 (not 0), so
    # column j of the tables corresponds to j - 1 blocks/devices actually used.
    dp[0][1] = 0

    # Dynamic programming to find minimum latency configuration
    for j in range(1, num_devices + 1):  # Device count
        for i in range(1, number_of_nodes + 1):  # End layer for this device
            for k in range(i):  # Split point
                # Resource constraints check
                # Debug output, printed for every (j, i, k) combination.
                print("\ndevices[j-1][\"maxLUT\"]\t" + str(devices[j-1]["maxLUT"]))
                print("lut_usage[k][i - 1]\t" + str(lut_usage[k][i - 1]))
                print("devices[j-1][\"maxBRAM\"]\t" + str(devices[j-1]["maxBRAM"]))
                print("bram_usage[k][i - 1]\t" + str(bram_usage[k][i - 1]))
                
                # NOTE(review): with the base case in column 1, column j = 1
                # can never receive a block (dp[k][0] is always inf), so the
                # limits of devices[0] never constrain a placed block; the
                # first block is checked against devices[1] -- confirm intent.
                if lut_usage[k][i - 1] <= devices[j-1]["maxLUT"] and bram_usage[k][i - 1] <= devices[j-1]["maxBRAM"]:
                    # Calculate latency for the current split configuration
                    # Communication is modelled as 0.01 per output element of
                    # the node just before the block.
                    comm_latency = output_sizes[k - 1] * 0.01 if k > 0 else 0
                    current_latency = dp[k][j - 1] + execution_latency[k][i - 1] + comm_latency
                    print("comm_latency:\t" + str(comm_latency))
                    # Update dp table if current latency is lower
                    if current_latency < dp[i][j]:
                        dp[i][j] = current_latency
                        split_point[i][j] = k


    # Find the minimum latency for the full model on the available devices
    min_latency = min(dp[number_of_nodes][j] for j in range(1, num_devices + 1))
    # Column index of the minimum; the first minimal column wins on ties.
    optimal_devices = min(range(0, num_devices + 1), key=lambda j: dp[number_of_nodes][j])

    # Reconstruct the optimal device assignment from split_point table
    def reconstruct_layers():
        assignments = []
        i, j = number_of_nodes, optimal_devices
        # Stops at column 1, the base-case column that holds no block.
        while j > 1:
            start_layer = split_point[i][j] #+ 1
            assignments.append((start_layer, i - 1))
            i, j = start_layer, j - 1
        return assignments[::-1]

    # Output results
    optimal_assignment = reconstruct_layers()
    # Convert the column index into the number of devices actually used.
    optimal_devices -= 1
    return {
        "min_latency": min_latency,
        "optimal_devices": optimal_devices,
        "assignments": optimal_assignment
    }

In [None]:

    # Hard-coded per-layer sample data (19 entries each).
    # NOTE(review): the second latency here is 8196.0, while the node_latency
    # list in the earlier data cell has 81960 at the same position -- confirm
    # which value is intended.
    # NOTE(review): these statements are indented at cell level; outside a
    # notebook front end that tolerates this, Python reports an unexpected indent.
    layer_latency = [30720, 8196.0, 324000, 142920, 282240, 9800, 26760, 207360, 37440, 288000, 1250, 3840, 207360, 1440,184320 ,327680, 327680, 5120, 20]  # Latency for each layer
    layer_resources = [16.0, 396.0, 2746.0, 428.0, 7410.0, 0, 428.0, 3855.0, 428.0, 3872.0, 0, 428.0, 1193.0, 428.0, 524.0,  336.0,363.0, 335.0, 0 ]  # Resource usage for each layer
    

In [None]:
def find_optimal_partition_dp(model_folded_path, maxNumberOfDevices, devices, number_of_nodes):
    """Find the lowest-latency split of the model into exactly ``maxNumberOfDevices`` blocks.

    ``dp[i][j]`` is the best total latency found for the part of the model up to cut
    position ``i`` when it is spread over exactly ``j`` devices. A candidate block is
    whatever ``get_one_layer(model_folded_path, k, i)`` cuts out; it is costed with
    ``generateReport`` and accepted for the ``j``-th device only if its total LUT and
    BRAM_18K usage fit ``devices[j-1]``.

    Args:
        model_folded_path: Path of the folded model, forwarded to ``get_one_layer``.
        maxNumberOfDevices: Number of devices; the reported result uses all of them.
        devices: Per-device limits; ``devices[j-1]`` must provide ``"maxLUT"`` and
            ``"maxBRAM"``.
        number_of_nodes: Index of the last cut position (the DP table has
            ``number_of_nodes + 1`` rows).

    Returns:
        ``(min_latency, optimal_configuration)``. ``optimal_configuration`` lists one
        dict per block, first block first. When no split satisfies the resource limits
        the result is ``(float('inf'), [])``.

    Note:
        ``communication_protocol`` is not a parameter; it is read from the enclosing
        scope when the communication cost is estimated.
    """
    dp = [[float('inf')] * (maxNumberOfDevices + 1) for _ in range(number_of_nodes + 1)]
    dp[0][0] = 0  # Base case: zero layers has zero latency
    # Back-pointers for reconstructing the best split; None marks "state never reached".
    prev = [[None] * (maxNumberOfDevices + 1) for _ in range(number_of_nodes + 1)]

    # The report of block (k, i) does not depend on the device index, so it is produced
    # once and reused for every device count instead of being regenerated each time.
    block_reports = {}

    # Step 2: Fill DP Table
    for j in range(1, maxNumberOfDevices + 1):  # Device count
        for i in range(1, number_of_nodes + 1):  # Cut position closing the block
            for k in range(i):  # Cut position opening the block
                # Step 3: Cut out the block and estimate its resources and latency
                if (k, i) not in block_reports:
                    block_path, last_output_shape, last_output_dtype = get_one_layer(model_folded_path, k, i)
                    block = ModelWrapper(block_path)
                    # NOTE(review): the meaning of the second argument (10) is not visible here.
                    layer_resources_report, performance_report = generateReport(block, 10)
                    block_reports[(k, i)] = (layer_resources_report, performance_report,
                                             last_output_shape, last_output_dtype)
                (estimate_layer_resources, estimate_network_performance,
                 last_output_shape, last_output_dtype) = block_reports[(k, i)]

                # Resource constraints of the device that would host this block
                if (float(estimate_layer_resources["total"]["LUT"]) <= float(devices[j-1]["maxLUT"]) and
                    float(estimate_layer_resources["total"]["BRAM_18K"]) <= float(devices[j-1]["maxBRAM"])):

                    # Communication cost of shipping the block's output
                    communication_cost = estimate_communication_cost(last_output_shape, last_output_dtype, 100, communication_protocol)

                    # Total latency
                    print(str(dp[k][j-1]) + "\t" + str(float(estimate_network_performance["estimated_latency_ns"])) + "\t" + str(communication_cost))
                    total_latency = dp[k][j-1] + float(estimate_network_performance["estimated_latency_ns"]) + communication_cost

                    # Update DP table if this configuration is better
                    print(total_latency)
                    if total_latency < dp[i][j]:
                        dp[i][j] = total_latency
                        prev[i][j] = (k, estimate_layer_resources, estimate_network_performance, communication_cost)

    # Step 4: Backtrack to find optimal configuration
    optimal_configuration = []
    i, j = number_of_nodes, maxNumberOfDevices
    while i > 0 and j > 0:
        step = prev[i][j]
        if step is None:
            # The state was never reached, i.e. no feasible split exists; stop instead of
            # unpacking a placeholder.
            break
        k, estimate_layer_resources, estimate_network_performance, communication_cost = step
        optimal_configuration.append({
            "start_layer": k,
            "end_layer": i,
            "estimate_layer_resources": str(estimate_layer_resources),
            "estimate_network_performance": str(estimate_network_performance),
            "communication_cost": communication_cost
        })
        i, j = k, j - 1

    optimal_configuration.reverse()  # Order the blocks from first to last
    min_latency = dp[number_of_nodes][maxNumberOfDevices]
    print(min_latency)
    print(optimal_configuration)

    return (min_latency, optimal_configuration)

In [None]:
def find_optimal_partition_dp(model_folded_path, maxNumberOfDevices, devices, number_of_nodes):
    """Print the best split of the hard-coded layer table over exactly ``maxNumberOfDevices`` FPGAs.

    Layers are numbered from 1. Each FPGA receives one run of consecutive layers whose
    summed resource figure may not exceed a fixed limit of 240000; the cost of a split
    is the sum of the layer latencies. Nothing is returned: the minimum latency and the
    ``(start_layer, end_layer)`` pairs are printed, or a message when no split exists.

    ``model_folded_path`` and ``devices`` are accepted but not used by this version.
    ``number_of_nodes`` may not exceed the 19 entries of the built-in tables.
    """
    unreachable = float('inf')
    row_count = maxNumberOfDevices + 1
    column_count = number_of_nodes + 1

    # best[d][n]: lowest latency for layers 1..n on exactly d FPGAs.
    best = [[unreachable] * column_count for _ in range(row_count)]
    best[0][0] = 0  # nothing placed on no device costs nothing
    # first_layer_of[d][n]: first layer of the block that ends at n in that optimum.
    first_layer_of = [[-1] * column_count for _ in range(row_count)]

    layer_latency = [30720, 8196.0, 324000, 142920, 282240, 9800, 26760, 207360, 37440, 288000, 1250, 3840, 207360, 1440,184320 ,327680, 327680, 5120, 20]  # Latency for each layer
    layer_resources = [16.0, 396.0, 2746.0, 428.0, 7410.0, 0, 428.0, 3855.0, 428.0, 3872.0, 0, 428.0, 1193.0, 428.0, 524.0,  336.0,363.0, 335.0, 0 ]  # Resource usage for each layer

    LUT_LIMIT = 240000

    for fpga_count in range(1, row_count):
        for last_layer in range(1, column_count):
            block_resources = 0
            block_latency = 0
            # Grow the block backwards from last_layer until it no longer fits.
            for first_layer in reversed(range(1, last_layer + 1)):
                block_resources += layer_resources[first_layer - 1]
                block_latency += layer_latency[first_layer - 1]
                if block_resources > LUT_LIMIT:
                    break
                candidate = best[fpga_count - 1][first_layer - 1] + block_latency
                if candidate < best[fpga_count][last_layer]:
                    best[fpga_count][last_layer] = candidate
                    first_layer_of[fpga_count][last_layer] = first_layer

    minimum = best[maxNumberOfDevices][number_of_nodes]
    if minimum == unreachable:
        print("No feasible solution found.")
        return

    print("Minimum latency:", minimum)

    # Walk the back-pointers from the last layer towards the first; inserting at the
    # front leaves the blocks in execution order.
    partitions = []
    last_layer, fpga_count = number_of_nodes, maxNumberOfDevices
    while last_layer > 0 and fpga_count > 0:
        first_layer = first_layer_of[fpga_count][last_layer]
        partitions.insert(0, (first_layer, last_layer))
        last_layer = first_layer - 1
        fpga_count -= 1

    print("Optimal partitions (start_layer, end_layer):", partitions)


In [None]:
# Scratch cell: the candidate-cost expression used inside the DP loops, evaluated on its own.
# NOTE(review): dp, k and current_latency only appear as function locals in the visible
# cells, so this statement presumably raises NameError unless they were defined by an
# earlier interactive run — confirm, or delete the cell.
added_latency = dp[i-1][k-1] + current_latency

In [None]:
# Codes used in the per-node output dtype table below.
INT8 = 8
BINARY = 1
UINT32 = 32


def find_optimal_partition_dp(model_folded_path, maxNumberOfDevices, devices, number_of_nodes, com_frequency):
    """Print the lowest-latency split of the built-in node tables over up to ``maxNumberOfDevices`` FPGAs.

    Nodes are numbered from 1 and every FPGA receives one run of consecutive nodes whose
    summed LUT figure stays within 22000. A block's cost is the sum of its node
    latencies; for every device after the first, the value returned by
    ``estimate_communication_cost`` for the block's last node is added as well. The
    device count with the smallest total is selected and the result is printed; the
    function returns nothing.

    ``com_frequency`` is forwarded to ``estimate_communication_cost`` together with a
    fixed protocol value of 1. ``model_folded_path`` and ``devices`` are not used.
    """
    protocol = 1
    unreachable = float('inf')

    # best_latency[d][n]: lowest cost of nodes 1..n on exactly d devices.
    best_latency = [[unreachable] * (number_of_nodes + 1) for _ in range(maxNumberOfDevices + 1)]
    best_latency[0][0] = 0  # no nodes on no devices

    # block_first_node[d][n]: first node of the block ending at n in that optimum.
    block_first_node = [[-1] * (number_of_nodes + 1) for _ in range(maxNumberOfDevices + 1)]

    node_latency = [30720, 8196.0, 324000, 142920, 282240, 9800, 26760, 207360, 37440, 288000, 1250, 3840, 207360, 1440,184320 ,327680, 327680, 5120, 20]
    node_lut = [16.0, 396.0, 2746.0, 428.0, 7410.0, 0, 428.0, 3855.0, 428.0, 3872.0, 0, 428.0, 1193.0, 428.0, 524.0,  336.0,363.0, 335.0, 0 ]
    last_output_shape = [[1, 32, 32, 3],[1, 30, 30, 27],[1, 30, 30, 64],[1, 28, 28, 576],[1, 28, 28, 64],[1, 14, 14, 64],[1, 12, 12, 576],[1, 12, 12, 128],[1, 10, 10, 1152],[1, 10, 10, 128],[1, 5, 5, 128],[1, 3, 3, 1152],[1, 3, 3, 256],[1, 1, 1, 2304],[1, 1, 1, 256],[1, 512],[1, 512],[1, 2],[1, 1]]
    last_output_dtype = [INT8,INT8,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,UINT32,BINARY]

    LUT_MAX = 22000
    BRAM_MAX = 0  # defined but not checked anywhere below

    # Fill the table device count by device count.
    for device_count in range(1, maxNumberOfDevices + 1):
        for last_node in range(1, number_of_nodes + 1):
            block_lut = 0
            block_latency = 0
            link_latency = estimate_communication_cost(
                last_output_shape[last_node - 1],
                last_output_dtype[last_node - 1],
                com_frequency,
                protocol,
            )

            # Extend the block backwards one node at a time until the LUT budget is exceeded.
            for first_node in reversed(range(1, last_node + 1)):
                block_lut += node_lut[first_node - 1]
                block_latency += node_latency[first_node - 1]
                if block_lut > LUT_MAX:
                    break

                candidate = best_latency[device_count - 1][first_node - 1] + block_latency
                # TODO: make the communication latency as variable
                if device_count > 1:
                    candidate += link_latency

                if candidate < best_latency[device_count][last_node]:
                    best_latency[device_count][last_node] = candidate
                    block_first_node[device_count][last_node] = first_node

    # Choose the device count whose full-model entry is smallest (first one wins ties).
    min_latency = unreachable
    optimal_device_count = 0
    for device_count in range(1, maxNumberOfDevices + 1):
        full_model_latency = best_latency[device_count][number_of_nodes]
        if full_model_latency < min_latency:
            min_latency, optimal_device_count = full_model_latency, device_count

    if min_latency == unreachable:
        print("No feasible solution found.")
        return

    print(f"Minimum latency: {min_latency} using {optimal_device_count} FPGA(s)")

    # Follow the back-pointers from the last node; inserting at the front keeps the
    # blocks in execution order.
    partitions = []
    last_node, device_count = number_of_nodes, optimal_device_count
    while last_node > 0 and device_count > 0:
        first_node = block_first_node[device_count][last_node]
        partitions.insert(0, (first_node, last_node))
        last_node = first_node - 1
        device_count -= 1

    print("Optimal partitions (start_layer, end_layer):", partitions)

In [None]:

    # Fragment of the table-driven DP: the same statements as the body of the
    # exact-device-count find_optimal_partition_dp, without a `def` line. It expects dp,
    # prev, layer_latency, layer_resources, maxNumberOfDevices and number_of_nodes to
    # exist already and defines none of them itself.
    # NOTE(review): the leading indentation has no enclosing statement in this cell —
    # presumably a leftover scratch copy; confirm whether it can be deleted.
    for i in range(1, maxNumberOfDevices + 1):  # Device count
        
        LUT_LIMIT = 240000#float(estimate_layer_resources["total"]["LUT"])
                          
        for j in range(1, number_of_nodes + 1):  # Layer count
            
            # Running totals for the block that ends at layer j
            current_resource_usage = 0
            current_latency = 0

            for k in range(j, 0, -1):  # k is the start layer of the block
                current_resource_usage += layer_resources[k-1]
                current_latency += layer_latency[k-1]
    
                # Ensure resource constraints are met
                if current_resource_usage > LUT_LIMIT:
                    break  # Stop if current block exceeds resource limits
    
                # Update dp[i][j] if it reduces the minimum latency
                #print("dp[1][3]:" + str(dp[1][3]))
                #print("i:" + str(i)+ "\tj:"  + str(j) + "\tk:"  + str(k))
                if dp[i][j] > dp[i-1][k-1] + current_latency:
                    dp[i][j] = dp[i-1][k-1] + current_latency
                    prev[i][j] = k  # Track where the current block starts

    # Only the entry that uses all maxNumberOfDevices devices is inspected.
    if dp[maxNumberOfDevices][number_of_nodes] == float('inf'):
        print("No feasible solution found.")
    else:
        print("Minimum latency:", dp[maxNumberOfDevices][number_of_nodes])
    
        # Backtrack to find optimal partitioning
        partitions = []
        remaining_layers = number_of_nodes
        fpgas_used = maxNumberOfDevices
        while remaining_layers > 0 and fpgas_used > 0:
            start_layer = prev[fpgas_used][remaining_layers]
            partitions.append((start_layer, remaining_layers))
            # Continue with the layers that precede this block, on one device fewer
            remaining_layers = start_layer - 1
            fpgas_used -= 1
    
        # Reverse to print in order
        partitions.reverse()
        print("Optimal partitions (start_layer, end_layer):", partitions)

In [None]:
# Codes used in the per-node output dtype table below.
INT8 = 8
BINARY = 1
UINT32 = 32
def find_optimal_partition_dp(model_folded_path, maxNumberOfDevices, devices, number_of_nodes, com_frequency):
    """Two-phase partitioning experiment over the built-in 19-node tables.

    Phase 1 is a dynamic program: ``dp[i][j]`` is the lowest cost of running nodes
    ``1..j`` on exactly ``i`` devices, where each device gets one run of consecutive
    nodes whose summed LUT figure stays within ``LUT_MAX``. A block costs the sum of
    its node latencies plus, for every device after the first, the value returned by
    ``estimate_communication_cost`` for the block's last node. The device count with
    the smallest full-model cost is selected.

    Phase 2 re-examines the start of every block; its communication term is still a
    zero placeholder.

    Args:
        model_folded_path: Not used by this version.
        maxNumberOfDevices: Largest number of devices that may be used.
        devices: Not used by this version (limits are the constants below).
        number_of_nodes: Number of nodes to place; at most the 19 table entries.
        com_frequency: Forwarded to ``estimate_communication_cost``.

    Returns:
        None. Results are printed. When no split satisfies the LUT limit, only
        "No feasible solution found." is printed.
    """
    protocol = 1
    LUT_MAX = 23000
    BRAM_MAX = 0  # placeholder: BRAM usage is not checked yet

    dp = [[float('inf')] * (number_of_nodes + 1) for _ in range(maxNumberOfDevices + 1)]
    dp[0][0] = 0  # Base case: zero layers has zero latency

    # To store the path to backtrack the optimal configuration
    prev = [[-1] * (number_of_nodes + 1) for _ in range(maxNumberOfDevices + 1)]

    node_latency = [30720, 8196.0, 324000, 142920, 282240, 9800, 26760, 207360, 37440, 288000, 1250, 3840, 207360, 1440,184320 ,327680, 327680, 5120, 20]
    node_lut = [16.0, 396.0, 2746.0, 428.0, 7410.0, 0, 428.0, 3855.0, 428.0, 3872.0, 0, 428.0, 1193.0, 428.0, 524.0,  336.0,363.0, 335.0, 0 ]
    last_output_shape = [[1, 32, 32, 3],[1, 30, 30, 27],[1, 30, 30, 64],[1, 28, 28, 576],[1, 28, 28, 64],[1, 14, 14, 64],[1, 12, 12, 576],[1, 12, 12, 128],[1, 10, 10, 1152],[1, 10, 10, 128],[1, 5, 5, 128],[1, 3, 3, 1152],[1, 3, 3, 256],[1, 1, 1, 2304],[1, 1, 1, 256],[1, 512],[1, 512],[1, 2],[1, 1]]
    last_output_dtype = [INT8,INT8,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,BINARY,UINT32,BINARY]

    # Phase 1: fill the DP table
    for i in range(1, maxNumberOfDevices + 1):  # Device count
        for j in range(1, number_of_nodes + 1):  # Last node of the block
            current_lut_usage = 0
            current_latency = 0

            # Only devices after the first pay for a transfer, so the estimate is
            # skipped when it would not be used.
            # TODO: make the communication latency as variable
            if i > 1:
                comm_latency = estimate_communication_cost(last_output_shape[j-1], last_output_dtype[j-1], com_frequency, protocol)
            else:
                comm_latency = 0

            for k in range(j, 0, -1):  # k is the start node of the block
                current_lut_usage += node_lut[k-1]
                current_latency += node_latency[k-1]

                # Growing the block further can only use more LUTs, so stop here.
                if current_lut_usage > LUT_MAX:
                    break

                added_latency = dp[i-1][k-1] + current_latency + comm_latency

                # Update dp[i][j] if it reduces the minimum latency
                if dp[i][j] > added_latency:
                    dp[i][j] = added_latency
                    prev[i][j] = k  # Track where the current block starts

    # Find the minimum latency across all configurations using up to maxNumberOfDevices devices
    min_latency = float('inf')
    optimal_device_count = 0
    for i in range(1, maxNumberOfDevices + 1):
        if dp[i][number_of_nodes] < min_latency:
            min_latency = dp[i][number_of_nodes]
            optimal_device_count = i

    if min_latency == float('inf'):
        print("No feasible solution found.")
        # There is no partition list to refine, so phase 2 must not run.
        return

    print(f"Minimum latency: {min_latency} using {optimal_device_count} FPGA(s)")

    # Backtrack to find the optimal partitioning for the chosen device count
    partitions = []
    remaining_nodes = number_of_nodes
    device_used = optimal_device_count
    while remaining_nodes > 0 and device_used > 0:
        start_node = prev[device_used][remaining_nodes]
        partitions.append((start_node, remaining_nodes))
        remaining_nodes = start_node - 1
        device_used -= 1

    # Reverse to print in order
    partitions.reverse()
    print("Optimal partitions (start_layer, end_layer):", partitions)

    # Phase 2: Communication cost optimization

    # Element count of every node output (product of its dimensions). Only the
    # size-based cost sketched in the TODO below would consume it.
    node_output_size = []
    for shape in last_output_shape:
        elements = 1
        for dim in shape:
            elements *= dim
        node_output_size.append(elements)

    optimized_partitions = []
    for start, end in partitions:
        min_block_latency = float('inf')
        best_split = start

        # Explore alternative split points within each partition
        for split_point in range(start, end):
            block_latency = sum(node_latency[node - 1] for node in range(start, split_point + 1))

            # TODO: replace the zero placeholder with a real link cost, e.g.
            #   estimate_communication_cost(last_output_shape[split_point-1], last_output_dtype[split_point-1], com_frequency, protocol)
            #   or node_output_size[split_point - 1] * COMM_COST_MULTIPLIER
            comm_latency = 0
            total_latency = block_latency + comm_latency

            # Find the split point with the minimum latency for this partition
            if total_latency < min_block_latency:
                min_block_latency = total_latency
                best_split = split_point

        # Append optimized split point for this FPGA
        optimized_partitions.append((best_split, end))

    print("Optimized partitions (start_layer, end_layer):", optimized_partitions)

    # The total is the DP entry for all nodes on the chosen device count. Entries for
    # intermediate block ends describe other sub-problems and must not be added to it.
    # NOTE(review): once phase 2 gets a real communication term and can move block
    # starts, the total has to be recomputed from optimized_partitions instead.
    final_latency = dp[optimal_device_count][number_of_nodes]
    print(f"Final minimum latency after optimization: {final_latency}")
    print("Optimized partitions:", optimized_partitions)

In [None]:

def find_optimal_partition_dp(model_folded_path, maxNumberOfDevices, devices, number_of_nodes, com_frequency):
    """Two-phase partitioning experiment driven by node tables from the enclosing scope.

    Phase 1 is a dynamic program: ``dp[i][j]`` is the lowest cost of running nodes
    ``1..j`` on exactly ``i`` devices, where each device gets one run of consecutive
    nodes whose summed LUT figure stays within ``LUT_MAX``. A block costs the sum of
    its node latencies plus, for every device after the first, the value returned by
    ``estimate_communication_cost`` for the block's last node. The device count with
    the smallest full-model cost is selected.

    Phase 2 re-examines the start of every block; its communication term is still a
    zero placeholder.

    ``node_latency``, ``node_lut``, ``last_output_shape``, ``last_output_dtype`` and
    ``protocol`` are not defined here; they are looked up in the enclosing scope when
    the corresponding statement runs.

    Args:
        model_folded_path: Not used by this version.
        maxNumberOfDevices: Largest number of devices that may be used.
        devices: Not used by this version (limits are the constants below).
        number_of_nodes: Number of nodes to place.
        com_frequency: Forwarded to ``estimate_communication_cost``.

    Returns:
        None. Results are printed. When no split satisfies the LUT limit, only
        "No feasible solution found." is printed.
    """
    LUT_MAX = 23000
    BRAM_MAX = 0  # placeholder: BRAM usage is not checked yet

    dp = [[float('inf')] * (number_of_nodes + 1) for _ in range(maxNumberOfDevices + 1)]
    dp[0][0] = 0  # Base case: zero layers has zero latency

    # To store the path to backtrack the optimal configuration
    prev = [[-1] * (number_of_nodes + 1) for _ in range(maxNumberOfDevices + 1)]

    # Phase 1: fill the DP table
    for i in range(1, maxNumberOfDevices + 1):  # Device count
        for j in range(1, number_of_nodes + 1):  # Last node of the block
            current_lut_usage = 0
            current_latency = 0

            # Only devices after the first pay for a transfer, so the estimate is
            # skipped when it would not be used.
            # TODO: make the communication latency as variable
            if i > 1:
                comm_latency = estimate_communication_cost(last_output_shape[j-1], last_output_dtype[j-1], com_frequency, protocol)
            else:
                comm_latency = 0

            for k in range(j, 0, -1):  # k is the start node of the block
                current_lut_usage += node_lut[k-1]
                current_latency += node_latency[k-1]

                # Growing the block further can only use more LUTs, so stop here.
                if current_lut_usage > LUT_MAX:
                    break

                added_latency = dp[i-1][k-1] + current_latency + comm_latency

                # Update dp[i][j] if it reduces the minimum latency
                if dp[i][j] > added_latency:
                    dp[i][j] = added_latency
                    prev[i][j] = k  # Track where the current block starts

    # Find the minimum latency across all configurations using up to maxNumberOfDevices devices
    min_latency = float('inf')
    optimal_device_count = 0
    for i in range(1, maxNumberOfDevices + 1):
        if dp[i][number_of_nodes] < min_latency:
            min_latency = dp[i][number_of_nodes]
            optimal_device_count = i

    if min_latency == float('inf'):
        print("No feasible solution found.")
        # There is no partition list to refine, so phase 2 must not run.
        return

    print(f"Minimum latency: {min_latency} using {optimal_device_count} FPGA(s)")

    # Backtrack to find the optimal partitioning for the chosen device count
    partitions = []
    remaining_nodes = number_of_nodes
    device_used = optimal_device_count
    while remaining_nodes > 0 and device_used > 0:
        start_node = prev[device_used][remaining_nodes]
        partitions.append((start_node, remaining_nodes))
        remaining_nodes = start_node - 1
        device_used -= 1

    # Reverse to print in order
    partitions.reverse()
    print("Optimal partitions (start_layer, end_layer):", partitions)

    # Phase 2: Communication cost optimization

    # Element count of every node output (product of its dimensions). Only the
    # size-based cost sketched in the TODO below would consume it.
    node_output_size = []
    for shape in last_output_shape:
        elements = 1
        for dim in shape:
            elements *= dim
        node_output_size.append(elements)

    optimized_partitions = []
    for start, end in partitions:
        min_block_latency = float('inf')
        best_split = start

        # Explore alternative split points within each partition
        for split_point in range(start, end):
            block_latency = sum(node_latency[node - 1] for node in range(start, split_point + 1))

            # TODO: replace the zero placeholder with a real link cost, e.g.
            #   estimate_communication_cost(last_output_shape[split_point-1], last_output_dtype[split_point-1], com_frequency, protocol)
            #   or node_output_size[split_point - 1] * COMM_COST_MULTIPLIER
            comm_latency = 0
            total_latency = block_latency + comm_latency

            # Find the split point with the minimum latency for this partition
            if total_latency < min_block_latency:
                min_block_latency = total_latency
                best_split = split_point

        # Append optimized split point for this FPGA
        optimized_partitions.append((best_split, end))

    print("Optimized partitions (start_layer, end_layer):", optimized_partitions)

    # The total is the DP entry for all nodes on the chosen device count. Entries for
    # intermediate block ends describe other sub-problems and must not be added to it.
    # NOTE(review): once phase 2 gets a real communication term and can move block
    # starts, the total has to be recomputed from optimized_partitions instead.
    final_latency = dp[optimal_device_count][number_of_nodes]
    print(f"Final minimum latency after optimization: {final_latency}")
    print("Optimized partitions:", optimized_partitions)

In [None]:
# Print the estimated communication cost of the first 19 node outputs. 100000000 and 1
# are passed in the positions the DP functions use for com_frequency and protocol.
# NOTE(review): last_output_shape and last_output_dtype must already exist at module
# level; in the visible cells they are only assigned inside functions — confirm.
for i in range(19):
    print(estimate_communication_cost(last_output_shape[i], last_output_dtype[i],100000000, 1))

In [None]:
import numpy as np
def optimal_partitioning(model_folded_path, maxNumberOfDevices, devices, number_of_nodes, com_frequency, com_protocol):
    """Split nodes ``0 .. number_of_nodes-1`` into consecutive blocks, one per device, minimising latency.

    ``dp[i][j]`` is the lowest cost of running the first ``i`` nodes on exactly ``j``
    devices. A block covers nodes ``k .. i-1``; it is admissible when its summed LUT
    figure is within ``max_LUT`` and its node count (used as a stand-in for BRAM) is
    within ``max_BRAM``. Its cost is the sum of its node latencies plus, when it is not
    the first block, one communication estimate.

    ``node_latency``, ``node_lut``, ``last_output_shape``, ``last_output_dtype``,
    ``max_LUT`` and ``max_BRAM`` are read from the enclosing scope.

    Args:
        model_folded_path: Not used by this version.
        maxNumberOfDevices: Largest number of devices that may be used.
        devices: Not used by this version.
        number_of_nodes: Number of nodes to place.
        com_frequency: Forwarded to ``estimate_communication_cost``.
        com_protocol: Forwarded to ``estimate_communication_cost``.

    Returns:
        ``(optimal_devices, optimal_modules, min_latency)`` where ``optimal_modules``
        is the list of ``(k, i)`` blocks in execution order. When nothing is feasible
        the result is ``(0, [], float('inf'))``.
    """
    # Initialize DP table: dp[i][j] where i is the node count, j is the device count
    dp = np.full((number_of_nodes + 1, maxNumberOfDevices + 1), float('inf'))
    path = [[[] for _ in range(maxNumberOfDevices + 1)] for _ in range(number_of_nodes + 1)]
    dp[0][0] = 0  # Base case: no layers and no devices has zero latency

    # Fill the DP table
    for i in range(1, number_of_nodes + 1):           # end of the block (exclusive)
        for j in range(1, maxNumberOfDevices + 1):    # number of devices used so far
            # The transfer estimate depends only on the block end, so it is evaluated
            # once here rather than once per node of every candidate block.
            # NOTE(review): this charges the output of node i-1 (the block's last
            # node); the original inline remark mentioned layers[k-1].output_size, i.e.
            # the tensor crossing the split in front of the block — confirm which one
            # is intended.
            if i > 1:
                split_comm_cost = estimate_communication_cost(last_output_shape[i-1], last_output_dtype[i-1], com_frequency, com_protocol)
            else:
                split_comm_cost = 0

            for k in range(i):                        # start of the block (inclusive)
                # Latency and resource usage of the block made of nodes k .. i-1
                block_nodes = range(k, i)
                block_latency = sum(node_latency[node] for node in block_nodes)
                block_LUT = sum(node_lut[node] for node in block_nodes)
                block_BRAM = len(block_nodes)  # placeholder: one unit per node
                # Only a block that follows an earlier one receives data over a link.
                comm_cost = split_comm_cost if k > 0 else 0

                # Check constraints
                if block_LUT <= max_LUT and block_BRAM <= max_BRAM:
                    total_latency = dp[k][j-1] + block_latency + comm_cost

                    # Update DP if this split gives a lower latency
                    if total_latency < dp[i][j]:
                        dp[i][j] = total_latency
                        path[i][j] = path[k][j-1] + [(k, i)]  # Record split point

    # Find the minimum latency configuration
    min_latency, optimal_devices = float('inf'), 0
    optimal_modules = []
    for j in range(1, maxNumberOfDevices + 1):
        if dp[number_of_nodes][j] < min_latency:
            min_latency = dp[number_of_nodes][j]
            optimal_devices = j
            optimal_modules = path[number_of_nodes][j]

    return optimal_devices, optimal_modules, min_latency

In [None]:
# Run the partitioning search, passing 10000000 as com_frequency.
# NOTE(review): several cells define find_optimal_partition_dp with different signatures;
# this five-argument call only fits the versions that accept com_frequency, so the
# result depends on which cell was executed last — confirm.
find_optimal_partition_dp(model_folded_path, maxNumberOfDevices, devices, number_of_nodes, 10000000)