In [1]:
# Basic Block Ordering

# No pointer, all of the pointers should be mapped to the indices starting from 0.
import copy

class BasicBlockNode:
  def __init__(self, BB_index=None):
    self.prev = None
    self.next = None
    self.BB_index = BB_index

class Chain:
  def __init__(self, head=None, tail=None):
    self.head = head
    self.tail = tail
  
  def Merge(self, s_to_merge):
    if s_to_merge is None:
      return self
    self.tail.next = s_to_merge.head
    s_to_merge.head.prev = self.tail
    self.tail = s_to_merge.tail
    self.tail.next = None
    return self
  
  def print_chain(self):
    temp = self.head
    while temp:
      print(temp.BB_index, end=',')
      temp = temp.next
    print('')

class Graph:
  def __init__(self, entry_point=None, adjacent_matrix=None, BB_size=None):
    # entry_point: int, index of the entry BB
    # adjacent_matrix: 2d array, adjacent_matrix[i][j] == 0 means no branch from
    #                  BB i to BB j, adjacent_matrix[i][j] > 0 means there is a branch
    #                  from BB i to BB j, and the value refers to w(i, j).
    # BB_size: list, containing the size of BBs and BB_size[i] is the size of BB i.
    self.entry_point = entry_point
    self.adjacent_matrix = adjacent_matrix
    self.BB_size = BB_size

class BBOrdering:
  def __init__(self, graph=None, entry_point=None):
    self.graph = graph
    self.entry_point = entry_point
  
  def ReorderingBasicBlocks(self):
    chains_set = set()
    # Initial chain creation
    # Note that i's are the actual BB indices.
    for i in range(len(self.graph.adjacent_matrix)):
        singleton_BB = BasicBlockNode(i)
        chains_set.add(Chain(head=singleton_BB, tail=singleton_BB))
    # Chain Merging
    gain = dict()
    # Best among six possibilities
    merged_dict = dict()
    while len(chains_set) > 1:
      for chain1 in chains_set:
        for chain2 in chains_set:
          if chain1 == chain2:
            continue
          if (chain1, chain2) not in gain:
            gain[(chain1, chain2)], merged_dict[(chain1, chain2)] = self.ComputeMergeGain(chain1, chain2)
      max_gain = 0
      chain1_max = None
      chain2_max = None
      for (chain1, chain2), gain_val in gain.items():
        if chain1 in chains_set and chain2 in chains_set and gain_val > max_gain:
          max_gain = gain_val
          chain1_max = chain1
          chain2_max = chain2
      chains_set.add(merged_dict[(chain1_max, chain2_max)])
      chains_set.remove(chain1_max)
      chains_set.remove(chain2_max)

    return list(chains_set)[0]


  def ComputeMergeGain(self, src, dst):
    # Try all ways to split chain src
    # cut_BB means the last BB node in s1
    cut_BB = src.head
    s1_start = src.head
    s1_end = cut_BB
    s2_start = cut_BB.next
    if cut_BB.next == None:
        s2_end = None
    else:
        s2_end = src.tail
    gain_max = float('-inf')
    merged_chain_max = None
    while cut_BB != None:
      s1_end = cut_BB
      s2_start = cut_BB.next
      s1 = Chain(s1_start, s1_end)
      if cut_BB.next == None:
        s2 = None
      else:
        s2 = Chain(s2_start, s2_end)

    #   gain_max = 0
    #   merged_chain_max = None
      curr = 0
      curr_merge = None
      # ExtTSP(s1, s2, dst)
      if dst.head.BB_index != self.entry_point:
        s_1 = copy.deepcopy(s1)
        s_2 = copy.deepcopy(s2)
        # s_1 = s1.copy()
        # s_2 = s2.copy()
        # dst_ = dst.copy()
        dst_ = copy.deepcopy(dst)
        curr_merge = s_1.Merge(s_2)
        curr_merge = curr_merge.Merge(dst_)
        curr = self.ExtTSP(curr_merge)
        if curr > gain_max:
          gain_max = curr
          merged_chain_max = curr_merge

      # ExtTSP(s1, dst, s2)
      if dst.head.BB_index != self.entry_point:
        s_1 = copy.deepcopy(s1)
        s_2 = copy.deepcopy(s2)
        # s_1 = s1.copy()
        # s_2 = s2.copy()
        # dst_ = dst.copy()
        dst_ = copy.deepcopy(dst)
        curr_merge = s_1.Merge(dst_)
        curr_merge = curr_merge.Merge(s_2)
        curr = self.ExtTSP(curr_merge)
        if curr > gain_max:
          gain_max = curr
          merged_chain_max = curr_merge

      # ExtTSP(s2, s1, dst)
      if dst.head.BB_index != self.entry_point and s1_start.BB_index != self.entry_point:
        s_1 = copy.deepcopy(s1)
        s_2 = copy.deepcopy(s2)
        # s_1 = s1.copy()
        # s_2 = s2.copy()
        # dst_ = dst.copy()
        dst_ = copy.deepcopy(dst)
        if s_2 is not None:
          curr_merge = s_2.Merge(s_1)
        else:
          curr_merge = s_1
        curr_merge = curr_merge.Merge(dst_)
        curr = self.ExtTSP(curr_merge)
        if curr > gain_max:
          gain_max = curr
          merged_chain_max = curr_merge

      # ExtTSP(s2, dst, s1)
      if dst.head.BB_index != self.entry_point and s1_start.BB_index != self.entry_point:
        s_1 = copy.deepcopy(s1)
        s_2 = copy.deepcopy(s2)
        # s_1 = s1.copy()
        # s_2 = s2.copy()
        # dst_ = dst.copy()
        dst_ = copy.deepcopy(dst)
        if s_2 is not None:
          curr_merge = s_2.Merge(dst_)
        else:
          curr_merge = dst_
        curr_merge = curr_merge.Merge(s_1)
        curr = self.ExtTSP(curr_merge)
        if curr > gain_max:
          gain_max = curr
          merged_chain_max = curr_merge

      # ExtTSP(dst, s1, s2)
      if s1_start.BB_index != self.entry_point:
        if not s2_start or s2_start.BB_index != self.entry_point:
          s_1 = copy.deepcopy(s1)
          s_2 = copy.deepcopy(s2)
          # s_1 = s1.copy()
          # s_2 = s2.copy()
          # dst_ = dst.copy()
          dst_ = copy.deepcopy(dst)
          curr_merge = dst_.Merge(s_1)
          curr_merge = curr_merge.Merge(s_2)
          curr = self.ExtTSP(curr_merge)
          if curr > gain_max:
            gain_max = curr
            merged_chain_max = curr_merge

      # ExtTSP(dst, s2, s1)
      if s1_start.BB_index != self.entry_point:
        if not s2_start or s2_start.BB_index != self.entry_point:
          s_1 = copy.deepcopy(s1)
          s_2 = copy.deepcopy(s2)
          # s_1 = s1.copy()
          # s_2 = s2.copy()
          # dst_ = dst.copy()
          dst_ = copy.deepcopy(dst)
          curr_merge = dst_.Merge(s_2)
          curr_merge = curr_merge.Merge(s_1)
          curr = self.ExtTSP(curr_merge)
          if curr > gain_max:
            gain_max = curr
            merged_chain_max = curr_merge
      cut_BB = cut_BB.next
      # print('!!!')
      # print("new gain", gain_max - self.ExtTSP(src) - self.ExtTSP(dst))
      # print("new order")
      # merged_chain_max.print_chain()
      # print('!!!')
    return gain_max - self.ExtTSP(src, print_or_not=False) - self.ExtTSP(dst, print_or_not=False), merged_chain_max
  
  def ExtTSP(self, subchains, print_or_not=True):
    # sudchains: a merged chain, consisted of three ordered suchains
    ##### dst_indx: int, may be 0, 1 or 2, indicating which subchain is the destination.
    curr = subchains.head
    BB_pos_dict = dict()
    pos_BB_list = []
    pos = 0
    while curr:
      BB_pos_dict[curr.BB_index] = pos
      pos_BB_list.append(curr.BB_index)
      # print("BB idx",curr.BB_index)
      pos += 1
      curr = curr.next

    score = 0

    # Note that i's are the positions in these ordered subchains.
    for i in range(len(pos_BB_list)):
      BB_idx = pos_BB_list[i]
      # Note that j's are the actual BB indices.
      for j in range(len(self.graph.adjacent_matrix)):
        if self.graph.adjacent_matrix[BB_idx][j] > 0 and j in BB_pos_dict:
          length_of_jump = 0
          # Forward branch
          if i < BB_pos_dict[j]:
            for k in range(i + 1, BB_pos_dict[j]):
              length_of_jump += self.graph.BB_size[pos_BB_list[k]]
          # Backward branch
          else:
            for k in range(BB_pos_dict[j], i + 1):
              length_of_jump += self.graph.BB_size[pos_BB_list[k]]
          if length_of_jump == 0:
            score += self.graph.adjacent_matrix[BB_idx][j] * 1
          elif 0 < length_of_jump <= 1024 and i < BB_pos_dict[j]:
            score += self.graph.adjacent_matrix[BB_idx][j] * 0.1 * (1 - length_of_jump / 1024)
          elif 0 < length_of_jump <= 640 and BB_pos_dict[j] < i:
            score += self.graph.adjacent_matrix[BB_idx][j] * 0.1 * (1 - length_of_jump / 640)
          else:
            score += 0
    if print_or_not:
      print("---------------")
      print('score: ', score)
      subchains.print_chain()
      print("---------------")
    return score


In [2]:
entry_point = 0
adjacent_matrix = [[0, 51, 49, 0],
                   [0, 0, 0, 51],
                   [0, 0, 0, 49],
                   [0, 0, 0, 0]]
BB_size = [250, 500, 300, 180]   

graph = Graph(entry_point, adjacent_matrix, BB_size)
ordering = BBOrdering(graph, entry_point)
new_ordering = ordering.ReorderingBasicBlocks()
new_ordering.print_chain()

---------------
score:  49
0,2,
---------------
---------------
score:  49
0,2,
---------------
---------------
score:  0
0,3,
---------------
---------------
score:  0
0,3,
---------------
---------------
score:  51
0,1,
---------------
---------------
score:  51
0,1,
---------------
---------------
score:  49
0,2,
---------------
---------------
score:  49
0,2,
---------------
---------------
score:  49
2,3,
---------------
---------------
score:  49
2,3,
---------------
---------------
score:  49
2,3,
---------------
---------------
score:  1.225
3,2,
---------------
---------------
score:  1.225
3,2,
---------------
---------------
score:  1.225
3,2,
---------------
---------------
score:  0
2,1,
---------------
---------------
score:  0
2,1,
---------------
---------------
score:  0
2,1,
---------------
---------------
score:  0
1,2,
---------------
---------------
score:  0
1,2,
---------------
---------------
score:  0
1,2,
---------------
---------------
score:  0
0,3,
--------